From 4ec876c9068e1238607be6096e1f3ca49e0b3e60 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 26 Mar 2026 14:32:40 -0700 Subject: [PATCH 1/5] feat: Add MLX simple.py entries from https://github.com/DBraun/sequence-layers/commit/80daa69bcb5a5580ff9fb73d13e416a1813b1462 Co-authored-by: David Braun <2096055+DBraun@users.noreply.github.com> --- sequence_layers/mlx/simple.py | 970 +++++++++++++++++++++++++++++ sequence_layers/mlx/simple_test.py | 508 +++++++++++++++ 2 files changed, 1478 insertions(+) create mode 100644 sequence_layers/mlx/simple.py create mode 100644 sequence_layers/mlx/simple_test.py diff --git a/sequence_layers/mlx/simple.py b/sequence_layers/mlx/simple.py new file mode 100644 index 0000000..830df0b --- /dev/null +++ b/sequence_layers/mlx/simple.py @@ -0,0 +1,970 @@ +"""Simple sequence layers for MLX.""" + +import dataclasses +import math + +from typing import Callable + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import init_mapping +from sequence_layers.mlx import types +from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig + +Sequence = bt.Sequence +MaskedSequence = bt.MaskedSequence + + +# --------------------------------------------------------------------------- +# Identity +# --------------------------------------------------------------------------- + + +class Identity(types.PreservesType, types.StatelessPointwise): + """Identity pass-through of the input.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + name: str | None = None + + def make(self) -> 'Identity': + return Identity.from_config(self) + + @types.check_layer + def layer(self, x, *, constants=None): + return x + + @classmethod + def from_config(cls, config): + return cls() + + +# --------------------------------------------------------------------------- +# Activation layers +# --------------------------------------------------------------------------- + + +class Relu(types.PreservesType, types.StatelessPointwiseFunctor): + """A Relu layer.""" + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return nn.relu(values), mask + + @classmethod + def from_config(cls, config): + return cls() + + +class Gelu(types.PreservesType, types.StatelessPointwiseFunctor): + """A Gelu layer.""" + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return nn.gelu(values), mask + + @classmethod + def from_config(cls, config): + return cls() + + +class Swish(types.PreservesType, types.StatelessPointwiseFunctor): + """A Swish/SiLU layer.""" + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return nn.silu(values), mask + + @classmethod + def from_config(cls, config): + return cls() + + +class Tanh(types.PreservesType, types.StatelessPointwiseFunctor): + """A tanh layer.""" + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return mx.tanh(values), mask + + @classmethod + def from_config(cls, config): + return cls() + + +class Sigmoid(types.PreservesType, types.StatelessPointwiseFunctor): + """A sigmoid layer.""" + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return mx.sigmoid(values), mask + + @classmethod + def from_config(cls, config): + return cls() + + +class LeakyRelu(types.PreservesType, types.StatelessPointwiseFunctor): + """A Leaky Relu layer.""" + + def __init__(self, negative_slope=0.01): + super().__init__() + self._negative_slope = negative_slope + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return nn.leaky_relu(values, self._negative_slope), mask + + @classmethod + def from_config(cls, config): + return cls(negative_slope=config.negative_slope) + + +class Elu(types.PreservesType, types.StatelessPointwiseFunctor): + """An ELU activation layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + alpha: complex = 1.0 + name: str | None = None + + def make(self) -> 'Elu': + return Elu.from_config(self) + + def __init__(self, alpha=1.0): + super().__init__() + self._alpha = alpha + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return nn.elu(values, self._alpha), mask + + @classmethod + def from_config(cls, config): + return cls(alpha=config.alpha) + + +class Softmax(types.PreservesType, types.StatelessPointwiseFunctor): + """A softmax layer.""" + + def __init__(self, axis=-1): + super().__init__() + self._axis = axis + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + axis = self._axis + if (axis if axis >= 0 else values.ndim + axis) < 2: + raise ValueError( + 'The softmax cannot be applied on the batch or time' + f' dimension (got {axis=} for shape={values.shape})' + ) + return mx.softmax(values, axis=axis), mask + + @classmethod + def from_config(cls, config): + return cls(axis=config.axis) + + +class Softplus(types.PreservesType, types.StatelessPointwiseFunctor): + """A softplus layer.""" + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return nn.softplus(values), mask + + @classmethod + def from_config(cls, config): + return cls() + + +# --------------------------------------------------------------------------- +# Value manipulation +# --------------------------------------------------------------------------- + + +class Cast(types.StatelessPointwiseFunctor): + """Cast input values to the specified type.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + dtype: object = mx.float32 + name: str | None = None + + def make(self) -> 'Cast': + return Cast.from_config(self) + + def __init__(self, dtype): + super().__init__() + self._dtype = dtype + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return values.astype(self._dtype), mask + + def get_output_dtype(self, input_dtype, *, constants=None): + return self._dtype + + @classmethod + def from_config(cls, config): + from sequence_layers.mlx.init_mapping import _to_mx_dtype + + return cls(dtype=_to_mx_dtype(config.dtype)) + + +class Scale(types.PreservesType, types.StatelessPointwise): + """Scales the input by a provided constant or array.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + scale: object = 1.0 + name: str | None = None + + def make(self) -> 'Scale': + return Scale.from_config(self) + + def __init__(self, scale): + super().__init__() + if isinstance(scale, (int, float, complex)): + self._scale = scale + else: + self._scale = mx.array(np.asarray(scale)) + + @types.check_layer + def layer(self, x, *, constants=None): + s = self._scale + if isinstance(s, mx.array): + s = s.astype(x.dtype) + return x.apply_values_masked(lambda v: v * s) + + @classmethod + def from_config(cls, config): + scale = config.scale + if hasattr(scale, 'data') and hasattr(scale, 'dtype'): + scale = np.array(scale.data, dtype=scale.dtype) + elif hasattr(scale, 'array'): + scale = np.asarray(scale.array) + return cls(scale=scale) + + +class Add(types.PreservesType, types.StatelessPointwise): + """Adds a provided constant or array to the input.""" + + def __init__(self, shift): + super().__init__() + if isinstance(shift, (int, float, complex)): + self._shift = shift + else: + self._shift = mx.array(np.asarray(shift)) + + @types.check_layer + def layer(self, x, *, constants=None): + s = self._shift + if isinstance(s, mx.array): + s = s.astype(x.dtype) + return x.apply_values(lambda v: v + s) + + @classmethod + def from_config(cls, config): + shift = config.shift + if hasattr(shift, 'data') and hasattr(shift, 'dtype'): + shift = np.array(shift.data, dtype=shift.dtype) + elif hasattr(shift, 'array'): + shift = np.asarray(shift.array) + return cls(shift=shift) + + +# --------------------------------------------------------------------------- +# Masking +# --------------------------------------------------------------------------- + + +class MaskInvalid(types.PreservesType, types.StatelessPointwise): + """Masks invalid timesteps to zero (or a specified value).""" + + def __init__(self, mask_value=None): + super().__init__() + self._mask_value = mask_value + + @types.check_layer + def layer(self, x, *, constants=None): + return x.mask_invalid(self._mask_value) + + @classmethod + def from_config(cls, config): + mask_value = getattr(config, 'mask_value', None) + return cls(mask_value=mask_value) + + +# --------------------------------------------------------------------------- +# Gated units +# --------------------------------------------------------------------------- + + +class GatedUnit(types.PreservesType, types.Stateless): + """Computes a generalized Gated Unit, reducing input channels by 2x.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + feature_activation: Callable | None = None + gate_activation: Callable | None = None + name: str | None = None + + def make(self) -> 'GatedUnit': + return GatedUnit.from_config(self) + + def __init__(self, feature_activation=None, gate_activation=None): + super().__init__() + self._feature_activation = feature_activation + self._gate_activation = gate_activation + + def get_output_shape(self, input_shape, *, constants=None): + channels = input_shape[-1] + if channels % 2 != 0: + raise ValueError( + f'Final dimension of input ({input_shape=}) must have' + ' an even number of channels.' + ) + return tuple(input_shape[:-1]) + (channels // 2,) + + @types.check_layer + def layer(self, x, *, constants=None): + feature, gate = mx.split(x.values, 2, axis=-1) + if self._feature_activation: + feature = self._feature_activation(feature) + if self._gate_activation: + gate = self._gate_activation(gate) + return Sequence(feature * gate, x.mask) + + @classmethod + def from_config(cls, config): + fa = init_mapping.map_activation(config.feature_activation) + ga = init_mapping.map_activation(config.gate_activation) + return cls(feature_activation=fa, gate_activation=ga) + + +class GatedLinearUnit(GatedUnit): + """Computes a Gated Linear Unit, reducing input channels by 2x.""" + + def __init__(self): + super().__init__( + feature_activation=None, + gate_activation=mx.sigmoid, + ) + + @classmethod + def from_config(cls, config): + return cls() + + +class GatedTanhUnit(GatedUnit): + """Computes a Gated Tanh Unit, reducing input channels by 2x.""" + + def __init__(self): + super().__init__( + feature_activation=mx.tanh, + gate_activation=mx.sigmoid, + ) + + @classmethod + def from_config(cls, config): + return cls() + + +# --------------------------------------------------------------------------- +# Shape manipulation +# --------------------------------------------------------------------------- + + +class Flatten(types.PreservesType, types.StatelessPointwise): + """Flattens the channel dimensions of the input sequence.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + name: str | None = None + + def make(self) -> 'Flatten': + return Flatten.from_config(self) + + def get_output_shape(self, input_shape, *, constants=None): + return (math.prod(input_shape),) + + @types.check_layer + def layer(self, x, *, constants=None): + batch_size, time = x.values.shape[:2] + num_elements = math.prod(x.channel_shape) + new_values = mx.reshape(x.values, (batch_size, time, num_elements)) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, x.mask) + return Sequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls() + + +class Reshape(types.PreservesType, types.Stateless): + """Reshapes the channels dimension of the input.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + output_shape: tuple[int, ...] = () + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'output_shape', tuple(self.output_shape)) + + def make(self) -> 'Reshape': + return Reshape.from_config(self) + + def __init__(self, output_shape): + super().__init__() + self._output_shape = tuple(output_shape) + + def _validate(self, input_shape): + in_elems = math.prod(input_shape) + out_elems = math.prod(self._output_shape) + if in_elems != out_elems: + raise ValueError( + f'Reshape output_shape={self._output_shape} must have' + f' the same number of elements as {input_shape=}.' + ) + + def get_output_shape(self, input_shape, *, constants=None): + self._validate(input_shape) + return self._output_shape + + @types.check_layer + def layer(self, x, *, constants=None): + self._validate(x.channel_shape) + b, t = x.values.shape[:2] + new_values = mx.reshape(x.values, (b, t) + self._output_shape) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, x.mask) + return Sequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls(output_shape=config.output_shape) + + +class ExpandDims(types.PreservesType, types.Stateless): + """Expands channel dimensions of the input sequence.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + axis: int | tuple[int, ...] = 0 + name: str | None = None + + def __post_init__(self): + if not isinstance(self.axis, int): + object.__setattr__(self, 'axis', tuple(self.axis)) + + def make(self) -> 'ExpandDims': + return ExpandDims.from_config(self) + + def __init__(self, axis): + super().__init__() + if isinstance(axis, int): + self._axis = (axis,) + else: + self._axis = tuple(axis) + + def _normalize_axes(self, input_shape): + rank = len(input_shape) + dims = sorted(set(a + rank + 1 if a < 0 else a for a in self._axis)) + for d in dims: + if d < 0 or d > rank: + raise ValueError( + f'ExpandDims axes must refer to channel dims. Got: {self._axis}.' + ) + return dims + + def get_output_shape(self, input_shape, *, constants=None): + dims = self._normalize_axes(input_shape) + out = list(input_shape) + for a in dims: + out.insert(a, 1) + return tuple(out) + + @types.check_layer + def layer(self, x, *, constants=None): + dims = [2 + d for d in self._normalize_axes(x.channel_shape)] + new_values = mx.expand_dims(x.values, axis=dims) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, x.mask) + return Sequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls(axis=config.axis) + + +class Squeeze(types.PreservesType, types.Stateless): + """Squeezes singleton channel dimensions of the input.""" + + def __init__(self, axis=None): + super().__init__() + self._axis = axis + + def _channel_squeeze_axes(self, input_shape): + """Return channel-relative axes to squeeze.""" + if self._axis is None: + # Squeeze all singleton channel dims. + return tuple(i for i, n in enumerate(input_shape) if n == 1) + # If axis is given, it's in full-tensor coords. Convert to channel. + if isinstance(self._axis, int): + axes = (self._axis,) + else: + axes = tuple(self._axis) + return axes + + def get_output_shape(self, input_shape, *, constants=None): + squeeze_axes = self._channel_squeeze_axes(input_shape) + out = [] + for i, s in enumerate(input_shape): + if i not in squeeze_axes: + out.append(s) + return tuple(out) if out else (1,) + + @types.check_layer + def layer(self, x, *, constants=None): + ch_axes = self._channel_squeeze_axes(x.channel_shape) + # Convert to full-tensor axes (offset by 2 for batch, time). + full_axes = tuple(a + 2 for a in ch_axes) + new_values = mx.squeeze(x.values, axis=full_axes) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, x.mask) + return Sequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls(axis=config.axis) + + +class Transpose(types.PreservesType, types.Stateless): + """Permutes the channel axes of the input.""" + + def __init__(self, axes=None): + super().__init__() + if axes is not None: + axes = tuple(axes) + self._axes = axes + + def _resolve_axes(self, input_shape): + input_axes = tuple(range(2, 2 + len(input_shape))) + if self._axes is None: + return input_axes[::-1] + sorted_axes = tuple(sorted(self._axes)) + if sorted_axes != input_axes: + raise ValueError( + f'Provided axes {sorted_axes} do not match input axes {input_axes}.' + ) + return tuple(self._axes) + + def get_output_shape(self, input_shape, *, constants=None): + axes = self._resolve_axes(input_shape) + return tuple(input_shape[a - 2] for a in axes) + + @types.check_layer + def layer(self, x, *, constants=None): + axes = self._resolve_axes(x.channel_shape) + new_values = mx.transpose(x.values, (0, 1) + axes) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, x.mask) + return Sequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls(axes=config.axes) + + +# --------------------------------------------------------------------------- +# Encoding +# --------------------------------------------------------------------------- + + +class OneHot(types.Stateless): + """Computes one-hot vector of the input.""" + + def __init__(self, depth, compute_dtype=mx.float32): + super().__init__() + self._depth = depth + self._compute_dtype = compute_dtype + + def get_output_shape(self, input_shape, *, constants=None): + return tuple(input_shape) + (self._depth,) + + def get_output_dtype(self, input_dtype, *, constants=None): + return self._compute_dtype + + @types.check_layer + def layer(self, x, *, constants=None): + def one_hot_fn(v): + indices = v.astype(mx.int32) + return mx.eye(self._depth, dtype=self._compute_dtype)[indices] + + return x.apply_values(one_hot_fn) + + @classmethod + def from_config(cls, config): + from sequence_layers.mlx.init_mapping import _to_mx_dtype + + return cls( + depth=config.depth, + compute_dtype=_to_mx_dtype(config.compute_dtype), + ) + + +class Embedding(types.Stateless): + """Computes embeddings of integer input codes. + + Backed by mlx.nn.Embedding. + """ + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + num_embeddings: int = 1 + dimension: int = 1 + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def make(self) -> 'Embedding': + return Embedding.from_config(self) + + def __init__( + self, + *, + num_embeddings: int, + dimension: int, + param_dtype=mx.float32, + compute_dtype=None, + ): + super().__init__() + self.num_embeddings = num_embeddings + self.dimension = dimension + self._param_dtype = param_dtype + self.compute_dtype = compute_dtype + self._embedding = nn.Embedding(num_embeddings, dimension) + + def get_output_shape(self, input_shape, *, constants=None): + return tuple(input_shape) + (self.dimension,) + + def get_output_dtype(self, input_dtype, *, constants=None): + if self.compute_dtype is not None: + return self.compute_dtype + return self._param_dtype + + @types.check_layer + def layer(self, x, *, constants=None): + def embed_fn(v): + result = self._embedding(v.astype(mx.int32)) + if self.compute_dtype is not None: + result = result.astype(self.compute_dtype) + return result + + return x.apply_values(embed_fn) + + @classmethod + def from_config(cls, config): + from sequence_layers.mlx.init_mapping import _to_mx_dtype + + compute_dtype = getattr(config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = _to_mx_dtype(compute_dtype) + return cls( + num_embeddings=config.num_embeddings, + dimension=config.dimension, + param_dtype=_to_mx_dtype(config.param_dtype), + compute_dtype=compute_dtype, + ) + + +# --------------------------------------------------------------------------- +# Regularization +# --------------------------------------------------------------------------- + + +class Dropout(types.PreservesType, types.StatelessPointwise): + """Dropout layer (pass-through during inference).""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + rate: float = 0.0 + broadcast_dims: tuple[int, ...] = () + name: str | None = None + + def make(self) -> 'Dropout': + return Dropout.from_config(self) + + def __init__(self, rate=0.0): + super().__init__() + self._rate = rate + + @types.check_layer + def layer(self, x, *, constants=None): + # Inference-only: dropout is a no-op. + return x + + @classmethod + def from_config(cls, config): + return cls(rate=config.rate) + + +# --------------------------------------------------------------------------- +# Sampling +# --------------------------------------------------------------------------- + + +class Downsample1D(types.PreservesType, types.Stateless): + """A 1D downsampling layer.""" + + def __init__(self, rate): + super().__init__() + self._rate = rate + + @property + def block_size(self): + return self._rate + + def get_output_shape(self, input_shape, *, constants=None): + return tuple(input_shape) + + @types.check_layer + def layer(self, x, *, constants=None): + new_values = x.values[:, :: self._rate] + new_mask = x.mask[:, :: self._rate] + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, new_mask) + return Sequence(new_values, new_mask) + + @classmethod + def from_config(cls, config): + return cls(rate=config.rate) + + +class Upsample1D(types.PreservesType, types.Stateless): + """A 1D upsampling layer.""" + + def __init__(self, rate): + super().__init__() + self._rate = rate + + def get_output_shape(self, input_shape, *, constants=None): + return tuple(input_shape) + + @types.check_layer + def layer(self, x, *, constants=None): + # Repeat each timestep `rate` times along the time axis. + b, t = x.values.shape[:2] + channel_shape = x.values.shape[2:] + # [b, t, 1, ...] -> [b, t, rate, ...] -> [b, t*rate, ...] + expanded = mx.expand_dims(x.values, axis=2) + tiled = mx.repeat(expanded, self._rate, axis=2) + new_values = mx.reshape(tiled, (b, t * self._rate) + channel_shape) + # Same for mask: [b, t] -> [b, t*rate] + new_mask = mx.repeat(mx.expand_dims(x.mask, axis=2), self._rate, axis=2) + new_mask = mx.reshape(new_mask, (b, t * self._rate)) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, new_mask) + return Sequence(new_values, new_mask) + + @classmethod + def from_config(cls, config): + return cls(rate=config.rate) + + +# --------------------------------------------------------------------------- +# CheckpointName (identity for inference) +# --------------------------------------------------------------------------- + + +class CheckpointName(types.PreservesType, types.StatelessPointwiseFunctor): + """Identity pass-through (checkpoint naming is JAX-only).""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + checkpoint_name: str = '' + name: str | None = None + + def make(self) -> 'CheckpointName': + return CheckpointName.from_config(self) + + def __init__(self, checkpoint_name=''): + super().__init__() + self._checkpoint_name = checkpoint_name + + @property + def mask_required(self): + return False + + def fn(self, values, mask): + return values, mask + + @classmethod + def from_config(cls, config): + return cls(checkpoint_name=config.checkpoint_name) + + +# --------------------------------------------------------------------------- +# Lambda +# --------------------------------------------------------------------------- + + +class Lambda(types.Stateless): + """A SequenceLayer that wraps a Python callable.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + fn: Callable = None + sequence_input: bool = False + mask_required: bool = True + # Accepted for JAX compatibility but ignored by MLX Lambda. + expected_input_spec: object = None + expected_output_spec: object = None + name: str | None = None + + def make(self) -> 'Lambda': + return Lambda.from_config(self) + + def __init__(self, fn, *, sequence_input=False, mask_required=True, + expected_output_spec=None): + super().__init__() + self._fn = fn + self._sequence_input = sequence_input + self._mask_required = mask_required + self._expected_output_spec = expected_output_spec + self._cached_output_spec = None + + def _probe_output(self, input_shape, input_dtype): + """Probe the function with a dummy to infer output shape/dtype.""" + if self._expected_output_spec is not None: + return self._expected_output_spec + if self._cached_output_spec is not None: + return self._cached_output_spec + try: + dummy_values = mx.zeros((1, 1) + tuple(input_shape), dtype=input_dtype) + dummy_mask = mx.ones((1, 1), dtype=mx.bool_) + if self._sequence_input: + result = self._fn(Sequence(dummy_values, dummy_mask)) + out_shape = result.values.shape[2:] + out_dtype = result.values.dtype + else: + out_values = self._fn(dummy_values) + out_shape = out_values.shape[2:] + out_dtype = out_values.dtype + self._cached_output_spec = bt.ShapeDType(out_shape, out_dtype) + return self._cached_output_spec + except Exception: + return None + + def get_output_shape(self, input_shape, *, constants=None): + spec = self._probe_output(input_shape, mx.float32) + if spec is not None: + return tuple(spec.shape) + return tuple(input_shape) + + def get_output_dtype(self, input_dtype, *, constants=None): + spec = self._probe_output((1,), input_dtype) + if spec is not None: + return spec.dtype + return input_dtype + + def layer(self, x, *, constants=None): + if self._sequence_input: + result = self._fn(x) + if not isinstance(result, (Sequence, MaskedSequence)): + raise ValueError( + 'Lambda with sequence_input=True must return a Sequence, ' + f'got {type(result)}' + ) + return result + else: + new_values = self._fn(x.values) + if self._mask_required or not isinstance(x, MaskedSequence): + return Sequence(new_values, x.mask) + return MaskedSequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls( + fn=config.fn, + sequence_input=config.sequence_input, + mask_required=config.mask_required, + expected_output_spec=getattr(config, 'expected_output_spec', None), + ) + + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- + + +class Logging(types.PreservesType, types.StatelessPointwise): + """Logs input info and returns the input unchanged.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + prefix: str = '' + dump_tensors: bool = False + name: str | None = None + + def make(self) -> 'Logging': + return Logging.from_config(self) + + def __init__(self, prefix='', dump_tensors=False): + super().__init__() + self._prefix = prefix + self._dump_tensors = dump_tensors + + @types.check_layer + def layer(self, x, *, constants=None): + if self._dump_tensors: + print(f'{self._prefix} layer(): x={x.values}') + else: + print( + f'{self._prefix} layer(): x.shape={x.shape}, ' + f'x.dtype={x.dtype}' + ) + return x + + @classmethod + def from_config(cls, config): + return cls( + prefix=config.prefix, + dump_tensors=config.dump_tensors, + ) diff --git a/sequence_layers/mlx/simple_test.py b/sequence_layers/mlx/simple_test.py new file mode 100644 index 0000000..08b19da --- /dev/null +++ b/sequence_layers/mlx/simple_test.py @@ -0,0 +1,508 @@ +"""Tests for simple MLX sequence layers.""" + +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import simple +from sequence_layers.mlx import test_utils + + +class IdentityTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Identity() + test_utils.verify_contract(self, layer, (4,)) + + def test_preserves_values(self): + layer = simple.Identity() + x = test_utils.random_sequence(2, 3, 4) + y = layer.layer(x) + np.testing.assert_array_equal(y.values, x.values) + np.testing.assert_array_equal(y.mask, x.mask) + + +class ReluTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Relu() + test_utils.verify_contract(self, layer, (4,)) + + def test_negative_zeroed(self): + layer = simple.Relu() + values = mx.array([[-1.0, 0.5, -0.3, 2.0]]).reshape(1, 1, 4) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = bt.MaskedSequence(values, mask) + y = layer.layer(x) + expected = mx.array([[[0.0, 0.5, 0.0, 2.0]]]) + np.testing.assert_allclose(y.values, expected, atol=1e-6) + + +class GeluTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Gelu() + test_utils.verify_contract(self, layer, (4,)) + + +class SwishTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Swish() + test_utils.verify_contract(self, layer, (4,)) + + +class TanhTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Tanh() + test_utils.verify_contract(self, layer, (4,)) + + def test_values(self): + layer = simple.Tanh() + values = mx.array([[[0.0, 1.0, -1.0, 100.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = bt.MaskedSequence(values, mask) + y = layer.layer(x) + np.testing.assert_allclose( + y.values, np.tanh([[[0.0, 1.0, -1.0, 100.0]]]), atol=1e-5 + ) + + +class SigmoidTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Sigmoid() + test_utils.verify_contract(self, layer, (4,)) + + +class LeakyReluTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.LeakyRelu(negative_slope=0.2) + test_utils.verify_contract(self, layer, (4,)) + + def test_negative_slope(self): + layer = simple.LeakyRelu(negative_slope=0.1) + values = mx.array([[[-2.0, 0.5, -1.0, 3.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = bt.MaskedSequence(values, mask) + y = layer.layer(x) + expected = mx.array([[[-0.2, 0.5, -0.1, 3.0]]]) + np.testing.assert_allclose(y.values, expected, atol=1e-6) + + +class EluTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Elu() + test_utils.verify_contract(self, layer, (4,)) + + +class SoftmaxTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Softmax() + test_utils.verify_contract(self, layer, (4,)) + + def test_sums_to_one(self): + layer = simple.Softmax(axis=-1) + values = mx.array([[[1.0, 2.0, 3.0, 4.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = bt.MaskedSequence(values, mask) + y = layer.layer(x) + np.testing.assert_allclose(float(mx.sum(y.values)), 1.0, atol=1e-5) + + +class SoftplusTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Softplus() + test_utils.verify_contract(self, layer, (4,)) + + +class CastTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Cast(dtype=mx.float16) + test_utils.verify_contract(self, layer, (4,), atol=1e-3, rtol=1e-3) + + def test_cast(self): + layer = simple.Cast(dtype=mx.float16) + x = test_utils.random_sequence(1, 3, 4) + y = layer.layer(x) + self.assertEqual(y.dtype, mx.float16) + + +class ScaleTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Scale(scale=2.0) + test_utils.verify_contract(self, layer, (4,)) + + def test_scalar(self): + layer = simple.Scale(scale=2.0) + values = mx.array([[[1.0, 2.0, 3.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = bt.MaskedSequence(values, mask) + y = layer.layer(x) + expected = mx.array([[[2.0, 4.0, 6.0]]]) + np.testing.assert_allclose(y.values, expected, atol=1e-6) + + +class AddTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Add(shift=1.0) + test_utils.verify_contract(self, layer, (4,)) + + def test_scalar(self): + layer = simple.Add(shift=10.0) + values = mx.array([[[1.0, 2.0, 3.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = bt.MaskedSequence(values, mask) + y = layer.layer(x) + expected = mx.array([[[11.0, 12.0, 13.0]]]) + np.testing.assert_allclose(y.values, expected, atol=1e-6) + + +class MaskInvalidTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.MaskInvalid() + test_utils.verify_contract(self, layer, (4,)) + + def test_masks_to_zero(self): + layer = simple.MaskInvalid() + values = mx.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]) + mask = mx.array([[True, False, True]]) + x = bt.Sequence(values, mask) + y = layer.layer(x) + expected = mx.array([[[1.0, 2.0], [0.0, 0.0], [5.0, 6.0]]]) + np.testing.assert_allclose(y.values, expected, atol=1e-6) + + +class GatedUnitTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.GatedUnit() + test_utils.verify_contract(self, layer, (8,)) + + def test_with_activations(self): + import mlx.nn as nn + + layer = simple.GatedUnit( + feature_activation=nn.relu, gate_activation=nn.sigmoid + ) + test_utils.verify_contract(self, layer, (8,)) + + +class GatedLinearUnitTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.GatedLinearUnit() + test_utils.verify_contract(self, layer, (8,)) + + def test_halves_channels(self): + layer = simple.GatedLinearUnit() + self.assertEqual(layer.get_output_shape((8,)), (4,)) + + +class GatedTanhUnitTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.GatedTanhUnit() + test_utils.verify_contract(self, layer, (8,)) + + +class FlattenTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Flatten() + test_utils.verify_contract(self, layer, (2, 3, 4)) + + def test_flatten(self): + layer = simple.Flatten() + self.assertEqual(layer.get_output_shape((2, 3, 4)), (24,)) + + +class ReshapeTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Reshape(output_shape=(2, 6)) + test_utils.verify_contract(self, layer, (12,)) + + def test_reshape(self): + layer = simple.Reshape(output_shape=(2, 6)) + x = test_utils.random_sequence(1, 3, 12) + y = layer.layer(x) + self.assertEqual(y.channel_shape, (2, 6)) + + def test_mismatch_raises(self): + layer = simple.Reshape(output_shape=(5,)) + with self.assertRaises(ValueError): + layer.get_output_shape((12,)) + + +class ExpandDimsTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.ExpandDims(axis=-1) + test_utils.verify_contract(self, layer, (4,)) + + def test_expand(self): + layer = simple.ExpandDims(axis=0) + self.assertEqual(layer.get_output_shape((4, 8)), (1, 4, 8)) + + def test_layer_values(self): + layer = simple.ExpandDims(axis=-1) + x = test_utils.random_sequence(1, 3, 4) + y = layer.layer(x) + self.assertEqual(y.channel_shape, (4, 1)) + + +class SqueezeTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Squeeze() + test_utils.verify_contract(self, layer, (4, 1)) + + def test_squeeze(self): + layer = simple.Squeeze() + x = bt.MaskedSequence( + mx.ones((1, 3, 1, 4, 1)), + mx.ones((1, 3), dtype=mx.bool_), + ) + y = layer.layer(x) + self.assertEqual(y.channel_shape, (4,)) + + +class TransposeTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Transpose() + test_utils.verify_contract(self, layer, (2, 3, 4)) + + def test_reverse(self): + layer = simple.Transpose() + self.assertEqual(layer.get_output_shape((2, 3, 4)), (4, 3, 2)) + + def test_explicit(self): + layer = simple.Transpose(axes=(3, 2, 4)) + self.assertEqual(layer.get_output_shape((5, 6, 7)), (6, 5, 7)) + + +class OneHotTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.OneHot(depth=5) + x = bt.MaskedSequence( + mx.array([[0, 2, 4]]), + mx.ones((1, 3), dtype=mx.bool_), + ) + y = layer.layer(x) + self.assertEqual(y.shape, (1, 3, 5)) + # Check that index 0 -> [1,0,0,0,0] + np.testing.assert_allclose(np.array(y.values[0, 0]), [1, 0, 0, 0, 0]) + + +class EmbeddingTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Embedding(num_embeddings=10, dimension=8) + x = bt.MaskedSequence( + mx.array([[1, 3, 5]]), + mx.ones((1, 3), dtype=mx.bool_), + ) + y = layer.layer(x) + self.assertEqual(y.shape, (1, 3, 8)) + + def test_output_shape(self): + layer = simple.Embedding(num_embeddings=10, dimension=8) + self.assertEqual(layer.get_output_shape(()), (8,)) + self.assertEqual(layer.get_output_shape((3,)), (3, 8)) + + +class DropoutTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Dropout(rate=0.5) + test_utils.verify_contract(self, layer, (4,)) + + def test_passthrough(self): + layer = simple.Dropout(rate=0.5) + x = test_utils.random_sequence(1, 3, 4) + y = layer.layer(x) + # Inference-only: should be identity. + np.testing.assert_array_equal(y.values, x.values) + + +class Downsample1DTest(parameterized.TestCase): + + def test_verify_contract(self): + layer = simple.Downsample1D(rate=2) + test_utils.verify_contract(self, layer, (4,)) + + def test_layer(self): + layer = simple.Downsample1D(rate=2) + x = test_utils.random_sequence(1, 6, 4) + y = layer.layer(x) + self.assertEqual(y.shape, (1, 3, 4)) + + def test_values(self): + layer = simple.Downsample1D(rate=3) + values = mx.arange(12).reshape(1, 6, 2).astype(mx.float32) + mask = mx.ones((1, 6), dtype=mx.bool_) + x = bt.MaskedSequence(values, mask) + y = layer.layer(x) + # Should keep timesteps 0, 3. + np.testing.assert_array_equal(y.values, values[:, ::3]) + + +class Upsample1DTest(parameterized.TestCase): + + def test_verify_contract(self): + layer = simple.Upsample1D(rate=3) + test_utils.verify_contract(self, layer, (4,)) + + def test_layer(self): + layer = simple.Upsample1D(rate=3) + x = test_utils.random_sequence(1, 4, 2) + y = layer.layer(x) + self.assertEqual(y.shape, (1, 12, 2)) + + def test_values(self): + layer = simple.Upsample1D(rate=2) + values = mx.array([[[1.0, 2.0], [3.0, 4.0]]]) + mask = mx.ones((1, 2), dtype=mx.bool_) + x = bt.MaskedSequence(values, mask) + y = layer.layer(x) + expected = mx.array([[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]]) + np.testing.assert_allclose(y.values, expected) + self.assertEqual(y.mask.shape, (1, 4)) + + +class BackendDispatchTest(parameterized.TestCase): + """Test config.make(backend='mlx') for simple layers.""" + + def test_identity(self): + import sequence_layers.mlx # Register backends. + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.Identity.Config() + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.Identity) + + def test_relu(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.Relu.Config() + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.Relu) + + def test_tanh(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.Tanh.Config() + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.Tanh) + + def test_gated_linear_unit(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.GatedLinearUnit.Config() + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.GatedLinearUnit) + + def test_reshape(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.Reshape.Config(output_shape=(2, 3)) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.Reshape) + + def test_downsample(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.Downsample1D.Config(rate=2) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.Downsample1D) + + +class CheckpointNameTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.CheckpointName(checkpoint_name='test') + test_utils.verify_contract(self, layer, (4,)) + + def test_passthrough(self): + layer = simple.CheckpointName(checkpoint_name='test') + x = test_utils.random_sequence(1, 3, 4) + y = layer.layer(x) + np.testing.assert_array_equal(y.values, x.values) + np.testing.assert_array_equal(y.mask, x.mask) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.CheckpointName.Config(checkpoint_name='test') + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.CheckpointName) + + +class LambdaTest(parameterized.TestCase): + + def test_values_fn(self): + layer = simple.Lambda(fn=lambda v: v * 2.0) + x = test_utils.random_sequence(1, 3, 4) + y = layer.layer(x) + np.testing.assert_allclose(y.values, x.values * 2.0, atol=1e-6) + + def test_sequence_fn(self): + def double_seq(s): + return bt.Sequence(s.values * 2.0, s.mask) + + layer = simple.Lambda(fn=double_seq, sequence_input=True) + x = test_utils.random_sequence(1, 3, 4) + y = layer.layer(x) + np.testing.assert_allclose(y.values, x.values * 2.0, atol=1e-6) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.Lambda.Config(fn=lambda v: v) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.Lambda) + + +class LoggingTest(parameterized.TestCase): + + def test_layer(self): + layer = simple.Logging(prefix='test') + test_utils.verify_contract(self, layer, (4,)) + + def test_passthrough(self): + layer = simple.Logging() + x = test_utils.random_sequence(1, 3, 4) + y = layer.layer(x) + np.testing.assert_array_equal(y.values, x.values) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.Logging.Config(prefix='test') + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.Logging) + + +if __name__ == '__main__': + absltest.main() From 23bde687f0ee52b7aad0f605230841ea83d4a783 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 16 Apr 2026 13:36:19 -0700 Subject: [PATCH 2/5] refactor(specs): Abstract simple layers into spec and extract shared behaviors --- pyproject.toml | 2 + sequence_layers/specs/__init__.py | 69 +- sequence_layers/specs/backend.py | 47 +- sequence_layers/specs/backend_behaviors.py | 56 ++ sequence_layers/specs/simple.py | 631 ++++++++++++++++ sequence_layers/specs/simple_behaviors.py | 796 +++++++++++++++++++++ sequence_layers/specs/test_utils.py | 47 +- sequence_layers/specs/types.py | 75 +- sequence_layers/specs/types_behaviors.py | 82 ++- 9 files changed, 1782 insertions(+), 23 deletions(-) create mode 100644 sequence_layers/specs/simple.py create mode 100644 sequence_layers/specs/simple_behaviors.py diff --git a/pyproject.toml b/pyproject.toml index 3b9626e..77991dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,8 @@ disable = [ "too-many-statements", "too-many-branches", "too-many-positional-arguments", + "too-many-public-methods", + "too-many-return-statements", ] diff --git a/sequence_layers/specs/__init__.py b/sequence_layers/specs/__init__.py index d523988..10df687 100644 --- a/sequence_layers/specs/__init__.py +++ b/sequence_layers/specs/__init__.py @@ -1,8 +1,11 @@ +"""Specifications for sequence layers.""" + # https://typing.python.org/en/latest/spec/protocol.html#modules-as-implementations-of-protocols from typing import Protocol, runtime_checkable, TYPE_CHECKING from . import backend as _backend +from . import simple as _simple from . import types as _types # Import test_utils only for type checking to avoid circular imports, @@ -11,6 +14,8 @@ from . import test_utils as _test_utils +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring @runtime_checkable class ModuleSpec(Protocol): """Protocol for a backend-specific SequenceLayers module (sequence_layers. as sl).""" @@ -23,12 +28,17 @@ def backend(self) -> _backend.ModuleSpec: def types(self) -> _types.ModuleSpec: ... + @property + def simple(self) -> _simple.ModuleSpec: + ... + @property def test_utils(self) -> '_test_utils.ModuleSpec': ... # Identifiers that backend-specific implementations should expose at top level. - # Demonstrating read-only allows for covariance (subclasses of types_module.Sequence to satisfy the protocol). + # Demonstrating read-only allows for covariance (subclasses of + # types_module.Sequence to satisfy the protocol). @property def Sequence(self) -> type[_types.Sequence]: @@ -49,3 +59,60 @@ def SequenceLayerConfig(self) -> type[_types.SequenceLayerConfig]: @property def SequenceLayerTest(self) -> type: ... + + # Privileged layers appearing top-level + @property + def Flatten(self) -> type[_simple.Flatten]: + ... + + @property + def Reshape(self) -> type[_simple.Reshape]: + ... + + @property + def ExpandDims(self) -> type[_simple.ExpandDims]: + ... + + @property + def Squeeze(self) -> type[_simple.Squeeze]: + ... + + @property + def Scale(self) -> type[_simple.Scale]: + ... + + @property + def Add(self) -> type[_simple.Add]: + ... + + @property + def Cast(self) -> type[_simple.Cast]: + ... + + @property + def MaskInvalid(self) -> type[_simple.MaskInvalid]: + ... + + @property + def GatedUnit(self) -> type[_simple.GatedUnit]: + ... + + @property + def GatedLinearUnit(self) -> type[_simple.GatedLinearUnit]: + ... + + @property + def GatedTanhUnit(self) -> type[_simple.GatedTanhUnit]: + ... + + @property + def OneHot(self) -> type[_simple.OneHot]: + ... + + @property + def Embedding(self) -> type[_simple.Embedding]: + ... + + @property + def Softmax(self) -> type[_simple.Softmax]: + ... diff --git a/sequence_layers/specs/backend.py b/sequence_layers/specs/backend.py index d65edfe..64a819c 100644 --- a/sequence_layers/specs/backend.py +++ b/sequence_layers/specs/backend.py @@ -26,16 +26,59 @@ def zeros(self, shape: tuple[int, ...], dtype: Any = None) -> Array: """Creates an array of zeros.""" def concatenate(self, arrays: list[Array], axis: int = 0) -> Array: - ... + """Concatenates arrays.""" + + def abs(self, x: Array) -> Array: + """Computes absolute value.""" + + def exp(self, x: Array) -> Array: + """Computes exponential.""" + + def log(self, x: Array) -> Array: + """Computes natural logarithm.""" + + +class nn(Protocol): + """Protocol for neural network operations (activations).""" + + def relu(self, x: Array) -> Array: + """Computes ReLU activation.""" + + def sigmoid(self, x: Array) -> Array: + """Computes sigmoid activation.""" + + def tanh(self, x: Array) -> Array: + """Computes tanh activation.""" + + def swish(self, x: Array) -> Array: + """Computes swish activation.""" + def gelu(self, x: Array) -> Array: + """Computes GeLU activation.""" + def elu(self, x: Array) -> Array: + """Computes ELU activation.""" + + def softplus(self, x: Array) -> Array: + """Computes softplus activation.""" + + def softmax(self, x: Array, axis: int = -1) -> Array: + """Computes softmax activation.""" + + +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring @runtime_checkable class ModuleSpec(Protocol): """Specification for sequence_layers..backend""" @property def xp(self) -> xp: - """Returns the NumPy-compatible interface.""" + ... + + @property + def nn(self) -> nn: + ... __all__ = [ diff --git a/sequence_layers/specs/backend_behaviors.py b/sequence_layers/specs/backend_behaviors.py index d72826c..6cda77b 100644 --- a/sequence_layers/specs/backend_behaviors.py +++ b/sequence_layers/specs/backend_behaviors.py @@ -4,6 +4,8 @@ from typing import override +import numpy as np + from sequence_layers import specs from sequence_layers.specs import backend as backend_spec from sequence_layers.specs import test_utils as test_utils_spec @@ -14,3 +16,57 @@ class ModuleSpecTest(test_utils_spec.ModuleSpecTest): @override def module_spec_pairs(self, backend_sl: specs.ModuleSpec): return {backend_sl.backend: backend_spec.ModuleSpec} + + +class BackendNNTest(test_utils_spec.SequenceLayerTest): + """Test behavior of backend.nn operations.""" + + def test_relu(self): + x = self.xp.array(np.array([[-1.0, 0.0, 1.0]], dtype=np.float32)) + y = self.nn.relu(x) + expected = self.xp.array(np.array([[0.0, 0.0, 1.0]], dtype=np.float32)) + self.assertAllEqual(y, expected) + + def test_sigmoid(self): + x = self.xp.array(np.array([[0.0]], dtype=np.float32)) + y = self.nn.sigmoid(x) + expected = self.xp.array(np.array([[0.5]], dtype=np.float32)) + self.assertAllEqual(y, expected) + + def test_tanh(self): + x = self.xp.array(np.array([[0.0]], dtype=np.float32)) + y = self.nn.tanh(x) + expected = self.xp.array(np.array([[0.0]], dtype=np.float32)) + self.assertAllEqual(y, expected) + + def test_elu(self): + x = self.xp.array(np.array([[0.0]], dtype=np.float32)) + y = self.nn.elu(x) + expected = self.xp.array(np.array([[0.0]], dtype=np.float32)) + self.assertAllEqual(y, expected) + + def test_softplus(self): + x = self.xp.array(np.array([[0.0]], dtype=np.float32)) + y = self.nn.softplus(x) + expected = self.xp.array(np.array([[np.log(2.0)]], dtype=np.float32)) + + # Wrap in Sequence to satisfy assertSequencesClose in JAX + y_seq = self.sl.types.Sequence.from_values(y) + expected_seq = self.sl.types.Sequence.from_values(expected) + + if hasattr(self, 'assertSequencesClose'): + self.assertSequencesClose(y_seq, expected_seq) + else: + self.assertAllEqual(y, expected) + + def test_swish(self): + x = self.xp.array(np.array([[0.0]], dtype=np.float32)) + y = self.nn.swish(x) + expected = self.xp.array(np.array([[0.0]], dtype=np.float32)) + self.assertAllEqual(y, expected) + + def test_gelu(self): + x = self.xp.array(np.array([[0.0]], dtype=np.float32)) + y = self.nn.gelu(x) + expected = self.xp.array(np.array([[0.0]], dtype=np.float32)) + self.assertAllEqual(y, expected) diff --git a/sequence_layers/specs/simple.py b/sequence_layers/specs/simple.py new file mode 100644 index 0000000..5d6faad --- /dev/null +++ b/sequence_layers/specs/simple.py @@ -0,0 +1,631 @@ +"""Specifications for simple layers. + +See the corresponding _behaviors module for behaviors. +""" + +# pylint: disable=abstract-method + +import abc +import dataclasses +from typing import (Any, Callable, Generic, Protocol, runtime_checkable, + Sequence, TypeVar) + +from sequence_layers.specs import types as types_spec +from sequence_layers.specs.types import HashableArray + +# --------------------------------------------------------------------------- +# Activation Functions (StatelessPointwiseFunctor) +# --------------------------------------------------------------------------- + + +class Identity[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwise[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Identity layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Identity layer.""" + + +class Relu[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Relu layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Relu layer.""" + + +class Gelu[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Gelu layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Gelu layer.""" + + +class Abs[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Abs layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Abs layer.""" + + +class Exp[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Exp layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Exp layer.""" + + +class Log[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Log layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Log layer.""" + + +class Swish[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Swish layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Swish layer.""" + + +class Tanh[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Tanh layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Tanh layer.""" + + +class Sigmoid[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Sigmoid layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Sigmoid layer.""" + + +class LeakyRelu[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for LeakyRelu layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for LeakyRelu layer.""" + + +class Elu[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Elu layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Elu layer.""" + + +class Softmax[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Softmax layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Softmax layer.""" + + +class Softplus[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Softplus layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Softplus layer.""" + + +# --------------------------------------------------------------------------- +# Simple Math and Pointwise (StatelessPointwise) +# --------------------------------------------------------------------------- + + +class Cast[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Cast layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Cast layer.""" + + +class Scale[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwise[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Scale layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Scale layer.""" + + +class Add[SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwise[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Add layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Add layer.""" + + +class MaskInvalid[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwise[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for MaskInvalid layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for MaskInvalid layer.""" + + +# --------------------------------------------------------------------------- +# Gating (Stateless) +# --------------------------------------------------------------------------- + + +T = TypeVar('T') + + +class GatedUnit[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for GatedUnit layer.""" + + @dataclasses.dataclass(frozen=True) + class Config[T](types_spec.SequenceLayerConfig): + """Configuration for GatedUnit layer.""" + + feature_activation: Callable[[T], T] | None + gate_activation: Callable[[T], T] | None + + +class GatedLinearUnit[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +](GatedUnit[SequenceT, ShapeDTypeT], metaclass=abc.ABCMeta): + """Specification for GatedLinearUnit layer.""" + + @dataclasses.dataclass(frozen=True) + class Config[T](GatedUnit.Config[T]): + """Configuration for GatedLinearUnit layer.""" + + +class GatedTanhUnit[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +](GatedUnit[SequenceT, ShapeDTypeT], metaclass=abc.ABCMeta): + """Specification for GatedTanhUnit layer.""" + + @dataclasses.dataclass(frozen=True) + class Config[T](GatedUnit.Config[T]): + """Configuration for GatedTanhUnit layer.""" + + +# --------------------------------------------------------------------------- +# Shape Operations (Stateless) +# --------------------------------------------------------------------------- + + +class Flatten[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Flatten layer.""" + + +class Reshape[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Reshape layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Reshape layer.""" + + output_shape: Sequence[int] + + +class ExpandDims[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for ExpandDims layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for ExpandDims layer.""" + + axis: int | Sequence[int] + + +class Squeeze[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Squeeze layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Squeeze layer.""" + + axis: int | Sequence[int] | None + + +class Transpose[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Transpose layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Transpose layer.""" + + axes: Sequence[int] | None + + +# --------------------------------------------------------------------------- +# Other Simple Layers +# --------------------------------------------------------------------------- + + +class OneHot[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for OneHot layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for OneHot layer.""" + + depth: int + + +class Embedding[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Embedding layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Embedding layer.""" + + dimension: int + num_embeddings: int + + +class Dropout[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwise[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Dropout layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Dropout layer.""" + + rate: float + + +class Downsample1D[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Downsample1D layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Downsample1D layer.""" + + rate: int + + +class Upsample1D[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Upsample1D layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Upsample1D layer.""" + + rate: int + + +class CheckpointName[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwiseFunctor[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for CheckpointName layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for CheckpointName layer.""" + + checkpoint_name: str + + +class Lambda[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.Stateless[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Lambda layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Lambda layer.""" + + fn: Callable[..., Any] + + +class Logging[ + SequenceT: types_spec.Sequence, ShapeDTypeT: types_spec.ChannelSpec +]( + types_spec.PreservesType[SequenceT, SequenceT, ShapeDTypeT], + types_spec.StatelessPointwise[SequenceT, SequenceT, ShapeDTypeT], + metaclass=abc.ABCMeta, +): + """Specification for Logging layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Logging layer.""" + + prefix: str + + +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring +@runtime_checkable +class ModuleSpec(Protocol): + """Protocol for simple layers module.""" + + @property + def Identity(self) -> type[Identity]: + ... + + @property + def Relu(self) -> type[Relu]: + ... + + @property + def Gelu(self) -> type[Gelu]: + ... + + @property + def Swish(self) -> type[Swish]: + ... + + @property + def Tanh(self) -> type[Tanh]: + ... + + @property + def Sigmoid(self) -> type[Sigmoid]: + ... + + @property + def LeakyRelu(self) -> type[LeakyRelu]: + ... + + @property + def Elu(self) -> type[Elu]: + ... + + @property + def Softmax(self) -> type[Softmax]: + ... + + @property + def Softplus(self) -> type[Softplus]: + ... + + @property + def Cast(self) -> type[Cast]: + ... + + @property + def Scale(self) -> type[Scale]: + ... + + @property + def Add(self) -> type[Add]: + ... + + @property + def MaskInvalid(self) -> type[MaskInvalid]: + ... + + @property + def GatedUnit(self) -> type[GatedUnit]: + ... + + @property + def GatedLinearUnit(self) -> type[GatedLinearUnit]: + ... + + @property + def GatedTanhUnit(self) -> type[GatedTanhUnit]: + ... + + @property + def Flatten(self) -> type[Flatten]: + ... + + @property + def Reshape(self) -> type[Reshape]: + ... + + @property + def ExpandDims(self) -> type[ExpandDims]: + ... + + @property + def Squeeze(self) -> type[Squeeze]: + ... + + @property + def Transpose(self) -> type[Transpose]: + ... + + @property + def OneHot(self) -> type[OneHot]: + ... + + @property + def Embedding(self) -> type[Embedding]: + ... + + @property + def Dropout(self) -> type[Dropout]: + ... + + @property + def Downsample1D(self) -> type[Downsample1D]: + ... + + @property + def Upsample1D(self) -> type[Upsample1D]: + ... + + @property + def CheckpointName(self) -> type[CheckpointName]: + ... + + @property + def Lambda(self) -> type[Lambda]: + ... + + @property + def Logging(self) -> type[Logging]: + ... + + +__all__ = [ + name + for name, attr in globals().items() + if isinstance(attr, type) and not name.startswith('_') +] diff --git a/sequence_layers/specs/simple_behaviors.py b/sequence_layers/specs/simple_behaviors.py new file mode 100644 index 0000000..195624a --- /dev/null +++ b/sequence_layers/specs/simple_behaviors.py @@ -0,0 +1,796 @@ +"""Behavior tests for simple layers. + +Backend-specific test files should inherit from these tests. +""" + +# pylint: disable=abstract-method +# pyrefly: disable=bad-instantiation + +from fractions import Fraction +from typing import Any, override +from unittest import mock + +from absl import logging +from absl.testing import parameterized +import numpy as np + +from sequence_layers.specs import simple as simple_spec +from sequence_layers.specs import test_utils + + +class ModuleSpecTest(test_utils.ModuleSpecTest): + """Test that a backend-specific module implements the ModuleSpec protocol.""" + + @override + def module_spec_pairs(self, backend_sl: Any) -> dict[Any, Any]: + return {backend_sl.simple: simple_spec.ModuleSpec} + + +class IdentityTest(test_utils.SequenceLayerTest): + """Test behavior of Identity layer.""" + + def test_defaults(self): + # pyrefly: ignore [missing-attribute] + self.assertConfigDefaults(self.sl.Identity.Config, {'name': None}) + + @parameterized.parameters((((2, 3, 5)),), (((2, 3, 5, 9)),)) + def test_identity(self, shape): + x = self.random_sequence(*shape) + # pyrefly: ignore [missing-attribute] + l = self.sl.Identity.Config(name='identity').make() + l = self.init_layer(l, x) + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.verify_contract(l, x, training=False) + + +class PointwiseMathTest(test_utils.SequenceLayerTest): + """Test behavior of pointwise math layers.""" + + def test_defaults(self): + # pyrefly: ignore [missing-attribute] + for layer_cls in [self.sl.Abs, self.sl.Exp, self.sl.Log]: + with self.subTest(layer=layer_cls.__name__): + self.assertConfigDefaults(layer_cls.Config, {'name': None}) + + def make_layer(self, layer_name): + """Helper to create a layer by name.""" + layer_cls = getattr(self.sl, layer_name) + return layer_cls.Config(name=layer_name.lower()).make() + + def test_pointwise_math(self): + params = [ + ('Relu', 'relu', False), + ('Sigmoid', 'sigmoid', False), + ('Tanh', 'tanh', False), + ('Elu', 'elu', False), + ('Softplus', 'softplus', False), + ('Swish', 'swish', False), + ('Gelu', 'gelu', False), + ('Abs', 'abs', True), + ('Exp', 'exp', True), + ('Log', 'log', True), + ('Softmax', 'softmax', False), + ] + for layer_name, method_name, is_xp in params: + with self.subTest(layer=layer_name): + x = self.random_sequence(2, 10, 4) + l = self.make_layer(layer_name) + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + y = self.verify_contract(l, x, training=False) + + activation = getattr( + self.sl.backend.xp if is_xp else self.nn, method_name + ) + y_expected = x.apply_values(activation).mask_invalid() + self.assertSequencesClose(y, y_expected, rtol=1e-5, atol=1e-5) + + @parameterized.parameters( + ('Softmax', 'softmax', -1), + ('Softmax', 'softmax', -2), + ('Softmax', 'softmax', 2), + ('Softmax', 'softmax', 3), + ) + def test_pointwise_math_axis(self, layer_name, method_name, axis): + batch_size, time, channels, channels2 = 2, 10, 4, 3 + x = self.random_sequence(batch_size, time, channels, channels2) + l = getattr(self.sl, layer_name).Config(name='test', axis=axis).make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual(l.get_output_shape_for_sequence(x), (channels, channels2)) + self.assertEqual(l.name, 'test') + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + activation = getattr(self.nn, method_name) + y_expected = x.apply_values( + lambda v: activation(v, axis=axis) + ).mask_invalid() + self.assertSequencesClose(y, y_expected) + + @parameterized.parameters( + ('Softmax', (2, 10, 4), -2), + ('Softmax', (2, 10, 4), -3), + ('Softmax', (2, 10, 4), 0), + ('Softmax', (2, 10, 4), 1), + ('Softmax', (2, 10), -1), + ) + def test_pointwise_math_axis_invalid(self, layer_name, shape, axis): + x = self.random_sequence(*shape) + l = getattr(self.sl, layer_name).Config(name='test', axis=axis).make() + + with self.assertRaises(ValueError): + l = self.init_layer(l, x) + l.layer(x, training=False) + + +class Downsample1DTest(test_utils.SequenceLayerTest): + """Test behavior of Downsample1D layer.""" + + @parameterized.parameters(((2, 3, 5), 2), ((2, 3, 5, 9), 3)) + def test_downsample1d(self, shape, rate): + x = self.random_sequence(*shape) + # pyrefly: ignore [missing-attribute] + l = self.sl.Downsample1D.Config(rate=rate, name='downsample_1d').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, rate) + self.assertEqual(l.output_ratio, Fraction(1, rate)) + + self.assertEqual(l.get_output_shape_for_sequence(x), x.channel_shape) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + np.testing.assert_array_equal(y.values, x.values[:, ::rate]) + np.testing.assert_array_equal(y.mask, x.mask[:, ::rate]) + + +class Upsample1DTest(test_utils.SequenceLayerTest): + """Test behavior of Upsample1D layer.""" + + @parameterized.parameters(((2, 3, 5), 2), ((2, 3, 5, 9), 3)) + def test_upsample1d(self, shape, rate): + x = self.random_sequence(*shape) + # pyrefly: ignore [missing-attribute] + l = self.sl.Upsample1D.Config(rate=rate, name='upsample_1d').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, rate) + + self.assertEqual(l.get_output_shape_for_sequence(x), x.channel_shape) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + for i in range(rate): + np.testing.assert_array_equal(x.values, y.values[:, i::rate]) + np.testing.assert_array_equal(x.mask, y.mask[:, i::rate]) + + +class TransposeTest(test_utils.SequenceLayerTest): + """Test behavior of Transpose layer.""" + + @parameterized.parameters( + ((2, 3, 4, 5), (2, 3), (4, 5)), + ((2, 3, 4, 5, 6), (4, 2, 3), (6, 4, 5)), + ((2, 3), None, ()), + ) + def test_transpose(self, input_shape, axes, _output_shape): + x = self.random_sequence(*input_shape) + # pyrefly: ignore [missing-attribute] + l = self.sl.Transpose.Config(axes=axes, name='transpose').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual(l.get_output_shape_for_sequence(x), _output_shape) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + # Verify shape and values + if axes is not None: + y_expected = x.apply_values(np.transpose, (0, 1) + axes) + else: + axes_seq = (0, 1) + tuple(range(2, x.ndim))[::-1] + y_expected = x.apply_values(np.transpose, axes_seq) + + self.assertSequencesEqual(y, y_expected) + + +class DropoutTest(test_utils.SequenceLayerTest): + """Test behavior of Dropout layer.""" + + def test_defaults(self): + self.assertConfigDefaults( + # pyrefly: ignore [missing-attribute] + self.sl.Dropout.Config, + {'rate': 0.0, 'name': None}, + ) + + def test_dropout_inference(self): + # pyrefly: ignore [missing-attribute] + l = self.sl.Dropout.Config(rate=0.5, name='dropout').make() + x = self.random_sequence(2, 3, 5) + l = self.init_layer(l, x) + y = l.layer(x, training=False) + # In inference, dropout should be identity + np.testing.assert_allclose(y.values, x.values) + + +class FlattenTest(test_utils.SequenceLayerTest): + """Test behavior of Flatten layer.""" + + @parameterized.parameters( + (((2, 3, 5)),), (((2, 3, 5, 9)),), (((2, 3, 5, 9, 2)),) + ) + def test_flatten(self, shape): + x = self.random_sequence(*shape) + # pyrefly: ignore [missing-attribute] + l = self.sl.Flatten.Config(name='flatten').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + num_elements = np.prod(shape[2:]) + + self.assertEqual(l.get_output_shape_for_sequence(x), (num_elements,)) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + # Verify shape + expected_shape = shape[:2] + (num_elements,) + self.assertEqual(y.values.shape, expected_shape) + + # Verify values + y_expected = x.apply_values(np.reshape, shape[:2] + (num_elements,)) + self.assertSequencesEqual(y, y_expected) + + +class ReshapeTest(test_utils.SequenceLayerTest): + """Test behavior of Reshape layer.""" + + @parameterized.parameters( + ((2, 3, 5), (1, 5, 1)), + ((2, 3, 5, 9), (3, 3, 5)), + ((2, 3, 1), ()), + ((2, 3), (1,)), + ) + def test_reshape(self, shape, output_shape): + x = self.random_sequence(*shape) + l = self.sl.Reshape.Config(output_shape, name='reshape').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual(l.get_output_shape_for_sequence(x), output_shape) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + # Verify shape + expected_shape = shape[:2] + output_shape + self.assertEqual(y.values.shape, expected_shape) + + # Verify values + y_expected = x.apply_values(np.reshape, shape[:2] + output_shape) + self.assertSequencesEqual(y, y_expected) + + +class ExpandDimsTest(test_utils.SequenceLayerTest): + """Test behavior of ExpandDims layer.""" + + def test_basic(self): + x = self.random_sequence(2, 3, 4) + l = self.sl.ExpandDims.Config(axis=-1, name='expand_dims').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + y = self.verify_contract(l, x, training=False) + self.assertEqual(y.values.shape, (2, 3, 4, 1)) + + def test_output_shape(self): + l = self.sl.ExpandDims.Config(axis=0, name='expand_dims').make() + self.assertEqual(l.get_output_shape((4, 8)), (1, 4, 8)) + + +class SqueezeTest(test_utils.SequenceLayerTest): + """Test behavior of Squeeze layer.""" + + @parameterized.named_parameters( + { + 'testcase_name': 'float_input', + 'input_array': np.array([[[3]]], dtype=np.float32), + 'expected_output': np.array([[3]]), + }, + { + 'testcase_name': 'int_input', + 'input_array': np.array([[[3]]], dtype=np.int32), + 'expected_output': np.array([[3]], dtype=np.int32), + }, + { + 'testcase_name': 'no_op_input', + 'input_array': np.array([[3]], dtype=np.float32), + 'expected_output': np.array([[3]]), + }, + { + 'testcase_name': 'input_with_extra_dims', + 'input_array': np.array([[[[[3], [4]]]]], dtype=np.float32), + 'expected_output': np.array([[[3, 4]]]), + }, + ) + def test_squeeze(self, input_array, expected_output): + x = self.sl.Sequence.from_values(input_array) + l = self.sl.Squeeze.Config(name='squeeze').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual( + l.get_output_shape_for_sequence(x), expected_output.shape[2:] + ) + test_receptive_field = np.issubdtype(input_array.dtype, np.inexact) + y = self.verify_contract( + l, x, training=False, test_receptive_field=test_receptive_field + ) + self.assertEmpty(self.get_variables(l)) + + # Verify shape + self.assertEqual(y.values.shape, expected_output.shape) + + # Verify values + np.testing.assert_allclose(y.values, expected_output) + + +class ScaleTest(test_utils.SequenceLayerTest): + """Test behavior of Scale layer.""" + + @parameterized.parameters(((2, 13, 5),), ((2, 13, 5, 9),)) + def test_basic(self, shape): + x = self.random_sequence(*shape) + l = self.sl.Scale.Config(scale=2.0, name='scale').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + # Verify values + y_expected = x.apply_values(lambda v: v * 2.0) + self.assertSequencesEqual(y, y_expected) + + @parameterized.parameters(((2, 13, 5),), ((2, 13, 9, 5),)) + def test_ndarray(self, shape): + x = self.random_sequence(*shape) + l = self.sl.Scale.Config( + scale=np.arange(5, dtype=np.float32), name='scale' + ).make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + # Verify values + y_expected = x.apply_values(lambda v: v * np.arange(5, dtype=np.float32)) + self.assertSequencesEqual(y, y_expected) + + def test_broadcast(self): + x = self.random_sequence(2, 3, 5, 1) + l = self.sl.Scale.Config(scale=np.ones((5, 9))).make() + l = self.init_layer(l, x) + + self.assertEqual(l.get_output_shape_for_sequence(x), (5, 9)) + y = self.verify_contract(l, x, training=False) + self.assertEqual(y.values.shape, (2, 3, 5, 9)) + self.assertEmpty(self.get_variables(l)) + + def test_too_many_dims(self): + x = self.random_sequence(2, 3, 5, 1) + l = self.sl.Scale.Config(scale=np.ones((5, 5, 5))).make() + l = self.init_layer(l, x, bind_only=True) + with self.assertRaises(ValueError): + l.get_output_shape(x.channel_shape) + with self.assertRaises(ValueError): + l.layer(x, training=False) + + def test_broadcast_failure(self): + x = self.random_sequence(2, 3, 5, 9) + l = self.sl.Scale.Config(scale=np.ones((5,))).make() + l = self.init_layer(l, x, bind_only=True) + with self.assertRaises(ValueError): + l.get_output_shape(x.channel_shape) + with self.assertRaises(ValueError): + l.layer(x, training=False) + + +class AddTest(test_utils.SequenceLayerTest): + """Test behavior of Add layer.""" + + @parameterized.parameters((((2, 13, 5)),), (((2, 13, 5, 9)),)) + def test_add(self, shape): + x = self.random_sequence(*shape) + l = self.sl.Add.Config(-2.0, name='add').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + # Verify values + y_expected = x.apply_values(lambda v: v - 2.0).mask_invalid() + self.assertSequencesEqual(y, y_expected) + + @parameterized.parameters(((2, 13, 5),), ((2, 13, 9, 5),)) + def test_ndarray(self, shape): + x = self.random_sequence(*shape) + l = self.sl.Add.Config( + shift=np.arange(5, dtype=np.float32), name='add' + ).make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) + y = self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + # Verify values + y_expected = x.apply_values( + lambda v: v + np.arange(5, dtype=np.float32) + ).mask_invalid() + self.assertSequencesEqual(y, y_expected) + + def test_broadcast(self): + x = self.random_sequence(2, 3, 5, 1) + l = self.sl.Add.Config(shift=np.ones((5, 9))).make() + l = self.init_layer(l, x) + + self.assertEqual(l.get_output_shape_for_sequence(x), (5, 9)) + y = self.verify_contract(l, x, training=False) + self.assertEqual(y.values.shape, (2, 3, 5, 9)) + self.assertEmpty(self.get_variables(l)) + + def test_too_many_dims(self): + x = self.random_sequence(2, 3, 5, 1) + l = self.sl.Add.Config(shift=np.ones((5, 5, 5))).make() + l = self.init_layer(l, x, bind_only=True) + with self.assertRaises(ValueError): + l.get_output_shape(x.channel_shape) + with self.assertRaises(ValueError): + l.layer(x, training=False) + + def test_broadcast_failure(self): + x = self.random_sequence(2, 3, 5, 9) + l = self.sl.Add.Config(shift=np.ones((5,))).make() + l = self.init_layer(l, x, bind_only=True) + with self.assertRaises(ValueError): + l.get_output_shape(x.channel_shape) + with self.assertRaises(ValueError): + l.layer(x, training=False) + + +class CastTest(test_utils.SequenceLayerTest): + """Test behavior of Cast layer.""" + + @parameterized.parameters( + (((2, 3, 5)), np.float16), + (((2, 3, 5, 9)), np.int32), + ) + def test_cast(self, shape, target_dtype): + x = self.random_sequence(*shape, dtype=np.float32) + l = self.sl.Cast.Config(target_dtype, name='cast').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) + test_receptive_field = np.issubdtype(target_dtype, np.inexact) + + pad_value = np.nan if target_dtype == np.float16 else 32768 + + y = self.verify_contract( + l, + x, + training=False, + padding_invariance_pad_value=pad_value, + test_receptive_field=test_receptive_field, + ) + self.assertEmpty(self.get_variables(l)) + + self.assertEqual(y.values.dtype, target_dtype) + + +class MaskInvalidTest(test_utils.SequenceLayerTest): + """Test behavior of MaskInvalid layer.""" + + def test_basic(self): + x = self.random_sequence(2, 15, 5) + l = self.sl.MaskInvalid.Config(name='mask_invalid').make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + self.assertEqual(l.get_output_shape_for_sequence(x), (5,)) + self.verify_contract(l, x, training=False) + self.assertEmpty(self.get_variables(l)) + + # Now test specific behavior + # Fill invalid values with NaN + x_nan = x.mask_invalid(np.nan) + + # Apply layer + y = l.layer(x_nan, training=False) + + # Verify that invalid values are masked (zeroed) + self.assertSequencesEqual(x.mask_invalid(), y) + + +class GatedUnitTest(test_utils.SequenceLayerTest): + """Test behavior of GatedUnit layers.""" + + def test_gated_activation(self): + shapes = ((2, 13, 6), (2, 13, 5, 10)) + + configs = [ + self.sl.GatedUnit.Config(None, None), # Bilinear + self.sl.GatedUnit.Config(None, self.nn.swish), # SwiGLU + self.sl.GatedUnit.Config(None, self.nn.gelu), # GeGLU + self.sl.GatedUnit.Config(lambda x: x, None), # Bilinear + self.sl.GatedUnit.Config(self.nn.swish, self.nn.tanh), + self.sl.GatedTanhUnit.Config(), + self.sl.GatedLinearUnit.Config(), + ] + + for shape in shapes: + for l_config in configs: + with self.subTest(shape=shape, config=str(l_config)): + x = self.random_sequence(*shape) + l = l_config.make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual( + l.get_output_shape_for_sequence(x), + shape[2:-1] + (shape[-1] // 2,), + ) + self.verify_contract(l, x, training=True) + + +class OneHotTest(test_utils.SequenceLayerTest): + """Test behavior of OneHot layer.""" + + @parameterized.parameters(((1, 2, 3),), ((2, 3, 5, 9),), ((2, 3, 5, 9, 2),)) + def test_one_hot(self, shape): + depth = 4 + l = self.sl.OneHot.Config(depth, name='one_hot').make() + x = self.random_sequence(*shape, dtype=self.xp.int32, low=0, high=depth - 1) + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:] + (depth,)) + self.assertEqual(l.name, 'one_hot') + + l = self.init_layer(l, x) + + y = self.verify_contract( + l, + x, + training=False, + padding_invariance_pad_value=0, + test_gradients=False, + test_receptive_field=False, + ) + self.assertAllEqual( + y.values, + ( + np.eye(depth)[np.array(x.values)].T + * np.array(x.mask).astype(np.float32).T + ).T, + ) + + +class EmbeddingTest(test_utils.SequenceLayerTest): + """Test behavior of Embedding layer.""" + + def test_defaults(self): + self.assertConfigDefaults( + self.sl.Embedding.Config, + {'dimension': 10, 'num_embeddings': 100, 'name': None}, + dimension=10, + num_embeddings=100, + ) + + def test_embedding(self): + shapes = [(1, 2, 3), (2, 3, 5, 9)] + dimension, num_embeddings = 8, 5 + + for shape in shapes: + with self.subTest(shape=shape): + l = self.sl.Embedding.Config( + dimension=dimension, num_embeddings=num_embeddings, name='embedding' + ).make() + x = self.random_sequence( + *shape, dtype=self.xp.int32, low=0, high=num_embeddings - 1 + ) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual( + l.get_output_shape(x.channel_shape), shape[2:] + (dimension,) + ) + + l = self.init_layer(l, x) + + self.verify_contract( + l, + x, + training=False, + test_gradients=False, + test_receptive_field=False, + ) + + +class LambdaTest(test_utils.SequenceLayerTest): + """Test behavior of Lambda layer.""" + + @parameterized.parameters(True, False) + def test_array_fn(self, mask_required: bool): + def fn(v): + if mask_required: + # Change the masked status by adding 1. + v = v + 1.0 + return v.reshape(v.shape + (1,)) > 0.5 + + l = self.sl.simple.Lambda.Config( + fn, + mask_required=mask_required, + expected_input_spec=self.sl.types.ChannelSpec((5,), self.xp.float32), + name='lambda', + ).make() + + x = self.random_sequence(2, 3, 5) + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + # Output spec reflects the changed shape and dtype. + self.assertEqual(l.get_output_shape(x.channel_shape), (5, 1)) + self.assertEqual(l.get_output_dtype(x.dtype), self.xp.bool_) + + y = self.verify_contract( + l, + x, + training=False, + # Receptive field test is not supported for bools. + test_receptive_field=False, + ) + + self.assertSequencesClose(y, x.apply_values(fn).mask_invalid()) + + @parameterized.parameters(True, False) + def test_sequence_fn(self, mask_required: bool): + def fn(x): + if mask_required: + # Change the masked status by adding 1. + x = x.apply_values(lambda v: v + 1.0) + return x.apply_values_masked(lambda v: v.reshape(v.shape + (1,)) > 0.5) + + l = self.sl.simple.Lambda.Config( + fn, + sequence_input=True, + expected_input_spec=self.sl.types.ChannelSpec((5,), self.xp.float32), + name='lambda', + ).make() + + x = self.random_sequence(2, 3, 5) + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + + # Output spec reflects the changed shape and dtype. + self.assertEqual(l.get_output_shape(x.channel_shape), (5, 1)) + self.assertEqual(l.get_output_dtype(x.dtype), self.xp.bool_) + + y = self.verify_contract( + l, + x, + training=False, + # Receptive field test is not supported for bools. + test_receptive_field=False, + ) + + self.assertSequencesClose(y, fn(x).mask_invalid()) + + +class CheckpointNameTest(test_utils.SequenceLayerTest): + """Test behavior of CheckpointName layer.""" + + def test_basic(self): + x = self.random_sequence(2, 3, 5) + l = self.sl.simple.CheckpointName.Config( + checkpoint_name='test', name='checkpoint_name' + ).make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual(l.get_output_shape(x.channel_shape), (5,)) + self.verify_contract(l, x, training=False) + + +# pylint: disable=missing-function-docstring +class Has: + """A simple `HAS(v)` matcher that tests whether something has `v` in it.""" + + def __init__(self, value): + self._v = value + + @override + def __eq__(self, o): + return self._v in o + + @override + def __ne__(self, o): + return not self == o + + @override + def __repr__(self): + return f'' + + +class Not: + """Negates a matcher.""" + + def __init__(self, matcher): + self._matcher = matcher + + @override + def __eq__(self, o): + return self._matcher != o + + @override + def __ne__(self, o): + return not self == o + + @override + def __repr__(self): + return f'' + + +class LoggingTest(test_utils.SequenceLayerTest): + """Test behavior of Logging layer.""" + + @mock.patch.object(logging, 'info', wraps=logging.info) + def test_logs_tensors(self, mock_logger): + x = self.sl.types.Sequence.from_values(self.xp.array([[1.414, 2, 3, 4]])) + training = False + + with self.subTest('prefix'): + l = self.sl.simple.Logging.Config(prefix='test string').make() + l = self.init_layer(l, x, bind_only=True) + l.layer(x, training=training) + mock_logger.assert_called_with(Has('test string')) diff --git a/sequence_layers/specs/test_utils.py b/sequence_layers/specs/test_utils.py index b20e529..d3e7dce 100644 --- a/sequence_layers/specs/test_utils.py +++ b/sequence_layers/specs/test_utils.py @@ -12,6 +12,7 @@ from sequence_layers import specs from sequence_layers.specs import backend as backend_spec from sequence_layers.specs import types as types_spec + _T = TypeVar('_T') @@ -19,7 +20,6 @@ class _AbcParameterizedTestCaseMeta(abc.ABCMeta, type(parameterized.TestCase)): """Metaclass for abstract parameterized test cases.""" - def zip_longest( targets: Iterable[Iterable[Any]], sources: Iterable[_T], @@ -145,6 +145,11 @@ def xp(self) -> backend_spec.xp: """Returns the backend wrapper.""" return self.sl.backend.xp + @property + def nn(self) -> backend_spec.nn: + """Returns the backend nn wrapper.""" + return self.sl.backend.nn + @abc.abstractmethod def assertSequencesEqual(self, x: SequenceT, y: SequenceT) -> None: # pylint: disable=invalid-name """Asserts that two sequences are equal.""" @@ -153,6 +158,10 @@ def assertSequencesEqual(self, x: SequenceT, y: SequenceT) -> None: # pylint: d def assertAllEqual(self, x: Any, y: Any) -> None: # pylint: disable=invalid-name """Asserts that all elements are equal.""" + def get_variables(self, layer: SequenceLayerT) -> dict[str, Any]: + """Returns the variables or parameters of the layer.""" + raise NotImplementedError + @abc.abstractmethod def random_sequence( self, @@ -167,6 +176,22 @@ def random_sequence( ) -> SequenceT: """Generates a random sequence.""" + @abc.abstractmethod + def init_layer( + self, + layer: types_spec.SequenceLayer, + x: types_spec.Sequence, + bind_only: bool = False, + ) -> types_spec.SequenceLayer: + """Initializes and binds a SequenceLayer for testing. + + Args: + layer: Layer to initialize and bind. + x: Example input sequence to use for initialization. + bind_only: If True, skip initialization and only bind the layer (if + applicable to the backend). + """ + @abc.abstractmethod def _step_by_step( self, @@ -199,6 +224,18 @@ def verify_contract( def assertSequencesClose(self, x: Any, y: Any, **kwargs) -> None: # pylint: disable=invalid-name """Asserts that two sequences are close.""" + def assertConfigDefaults( # pylint: disable=invalid-name + self, config_cls: type, expected_defaults: dict[str, Any], **kwargs + ) -> None: + """Helper to verify that a config class has the expected defaults.""" + config = config_cls(**kwargs) + for field_name, expected_val in expected_defaults.items(): + self.assertEqual( + getattr(config, field_name), + expected_val, + f'Default for {field_name} in {config_cls.__name__} does not match!', + ) + class ModuleSpecTest(SequenceLayerTest): """Test that a backend-specific module implements the ModuleSpec protocol.""" @@ -218,6 +255,8 @@ def test_module_spec_with_typeguard(self) -> None: typeguard.check_type('backend_module', mod, protocol) +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring @runtime_checkable class ModuleSpec(Protocol): """Specification for sequence_layers..test_utils""" @@ -227,17 +266,17 @@ def zip_longest( targets: Iterable[Iterable[Any]], sources: Iterable[Any], ) -> list[Any]: - """Zips targets and sources.""" + ... def named_product( self, first: Iterable[Any], second: Iterable[Any], ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - """Creates a named product.""" + ... @property - def SequenceLayerTest(self) -> type: # pylint: disable=invalid-name + def SequenceLayerTest(self) -> type: ... diff --git a/sequence_layers/specs/types.py b/sequence_layers/specs/types.py index 5d5e8be..4c89080 100644 --- a/sequence_layers/specs/types.py +++ b/sequence_layers/specs/types.py @@ -28,6 +28,20 @@ DType = Any # Can be numpy, jax, or mlx dtype +@runtime_checkable +class HashableArray(Protocol): + """Protocol for hashable multidimensional arrays.""" + + data: complex | tuple[Any, ...] + """The data as a tuple or complex scalar.""" + + dtype: Any + """The dtype of the array.""" + + def to_array(self) -> Any: + """Returns the array representation.""" + + class ChannelSpec(Protocol): """Protocol for channel specifications.""" @@ -39,6 +53,9 @@ def shape(self) -> Shape: def dtype(self) -> Any: """The dtype of the channel.""" + def __init__(self, shape: Shape, dtype: Any): + ... + State = Any Constants = MutableMapping[str, jt.PyTree[Array]] @@ -172,7 +189,13 @@ class PaddingMode(enum.Enum): class Sequence[ValuesT = Array, MaskT = Array](metaclass=abc.ABCMeta): - """Abstract base class for Sequence.""" + """A generic sequence container that preserves masking information. + + Note: This class can hold non-backend-specific arrays (like `np.ndarray`) to + maintain consistency with JAX. Backend implementations should handle them + gracefully, for example by converting to backend-native arrays just-in-time + when backend-specific operations require it. + """ values: ValuesT mask: MaskT @@ -327,13 +350,18 @@ def apply_masked[NewValuesT: Array, NewMaskT: Array, **P]( class SequenceLayerConfig(metaclass=abc.ABCMeta): """Configuration for a SequenceLayer.""" + def __init__(self, *args: Any, **kwargs: Any): + pass + @abc.abstractmethod def make(self) -> Any: """Creates the sequence layer.""" - @abc.abstractmethod def copy(self, **kwargs: Any) -> Self: """Returns a copy of the config with updated fields.""" + import dataclasses + + return dataclasses.replace(self, **kwargs) class Steppable[ @@ -592,6 +620,15 @@ class SequenceLayer[ ](Steppable[InputT, OutputT, ChannelSpecT], metaclass=abc.ABCMeta): """Base class for Sequence Layers.""" + @abc.abstractmethod + def get_output_shape_for_sequence( + self, + x: Sequence, + *, + constants: Constants | None = None, + ) -> Shape: + """Returns the output shape this layer produces for the provided Sequence.""" + # --------------------------------------------------------------------------- # Mixins @@ -891,59 +928,67 @@ def layer_with_emits( ... +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring @runtime_checkable class ModuleSpec(Protocol): """Specification for sequence_layers..types""" - # pylint: disable=invalid-name + @property + def ChannelSpec(self) -> type[ChannelSpec]: + ... + + @property + def HashableArray(self) -> type[HashableArray]: + ... @property def Sequence(self) -> type[Sequence]: - """The Sequence class for this backend.""" + ... @property def MaskedSequence(self) -> type[MaskedSequence]: - """The MaskedSequence class for this backend.""" + ... @property def SequenceLayer(self) -> type[SequenceLayer]: - """The SequenceLayer class for this backend.""" + ... @property def SequenceLayerConfig(self) -> type[SequenceLayerConfig]: - """The SequenceLayerConfig class for this backend.""" + ... @property def Steppable(self) -> type[Steppable]: - """The Steppable class for this backend.""" + ... @property def PreservesShape(self) -> type[PreservesShape]: - """The PreservesShape class for this backend.""" + ... @property def Stateless(self) -> type[Stateless]: - """The Stateless class for this backend.""" + ... @property def StatelessPointwise(self) -> type[StatelessPointwise]: - """The StatelessPointwise class for this backend.""" + ... @property def StatelessPointwiseFunctor(self) -> type[StatelessPointwiseFunctor]: - """The StatelessPointwiseFunctor class for this backend.""" + ... @property def PreservesType(self) -> type[PreservesType]: - """The PreservesType class for this backend.""" + ... @property def Emitting(self) -> type[Emitting]: - """The Emitting class for this backend.""" + ... @property def StatelessEmitting(self) -> type[StatelessEmitting]: - """The StatelessEmitting class for this backend.""" + ... __all__ = [ diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py index 891bc56..b9a7d6c 100644 --- a/sequence_layers/specs/types_behaviors.py +++ b/sequence_layers/specs/types_behaviors.py @@ -36,7 +36,6 @@ def test_backend_specific_types_are_subclasses(self) -> None: self.assertTrue(issubclass(mod.Steppable, types_spec.Steppable)) - class DummyChannelSpec(NamedTuple): """Dummy channel spec for testing.""" @@ -85,6 +84,15 @@ def get_accumulated_input_latency(self, input_latency: int) -> int: def get_accumulated_output_latency(self, output_latency: int) -> int: return output_latency + @override + def get_output_shape_for_sequence( + self, + x: types_spec.Sequence, + *, + constants: types_spec.Constants | None = None, + ) -> types_spec.Shape: + return x.channel_shape + @override def layer( self, @@ -181,6 +189,23 @@ def test_backend_specific_module_has_interface(self) -> None: class SequenceTest(SequenceLayerTest): """Abstract tests for the Sequence class.""" + def test_accepts_numpy_arrays(self) -> None: + """Tests that Sequence accepts numpy arrays and doesn't convert them.""" + values_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + mask_np = np.array([[True, False], [False, True]], dtype=bool) + + x = self.sl.Sequence(values_np, mask_np) + + # Verify they are stored as numpy arrays + self.assertIsInstance(x.values, np.ndarray) + self.assertIsInstance(x.mask, np.ndarray) + + # And verify that operations like mask_invalid still work (returning backend arrays for values) + masked = x.mask_invalid() + self.assertNotIsInstance(masked.values, np.ndarray) + # We don't assert masked.mask is not ndarray, as it might stay ndarray if it + # was initialized as such, following JAX behavior. + @parameterized.named_parameters( ('mask_value=None', 0.0, None), ('mask_value=0.0', 0.0, 0.0), @@ -781,6 +806,36 @@ def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: return DummyLayer() + def test_mask_required_default(self) -> None: + """Tests that mask_required defaults to True.""" + backend_sl = self.sl + + class DefaultLayer( + DefaultTestLayer, backend_sl.types.StatelessPointwiseFunctor + ): + """Mock layer for testing defaults.""" + + def fn(self, values: Any, mask: Any) -> tuple[Any, Any]: + """Pointwise function.""" + return values, mask + + @override + def layer(self, *args, **kwargs): + """Calls base layer.""" + return backend_sl.types.StatelessPointwiseFunctor.layer( + self, *args, **kwargs + ) + + @override + def get_output_shape(self, *args, **kwargs): + """Calls base get_output_shape.""" + return backend_sl.types.StatelessPointwiseFunctor.get_output_shape( + self, *args, **kwargs + ) + + layer = DefaultLayer() + self.assertTrue(layer.mask_required) + def create_sequence( self, ) -> types_spec.Sequence[types_spec.Array, types_spec.Array]: @@ -810,3 +865,28 @@ def test_layer_applies_fn_based_on_mask_required(self) -> None: else: mock_apply_masked.assert_called_once() mock_apply.assert_not_called() + + +class HashableArrayTest(test_utils_spec.SequenceLayerTest): + """Tests for HashableArray.""" + + def test_hashable_array(self) -> None: + # We need to get HashableArray from the backend types! + HashableArray = self.sl.types.HashableArray + + # Create a numpy array + x = np.array([[1.0, 2.0], [3.0, 4.0]]) + + # Create HashableArray + ha = HashableArray.from_array(x) + + # Check properties + self.assertEqual(ha.dtype, x.dtype) + + # Check to_array + x_back = ha.to_array() + np.testing.assert_array_equal(x, x_back) + + # Check hashability + h = hash(ha) + self.assertIsInstance(h, int) From 340c8aded35a3a0de173ce0c44e9716854e944ae Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 16 Apr 2026 13:41:24 -0700 Subject: [PATCH 3/5] refactor(jax): Use shared behaviors in tests for simple layers --- pyproject.toml | 5 +- sequence_layers/jax/backend.py | 53 ++ sequence_layers/jax/simple.py | 981 +++++++++++++++++++++-------- sequence_layers/jax/simple_test.py | 678 ++------------------ sequence_layers/jax/test_utils.py | 10 +- sequence_layers/jax/types.py | 44 +- sequence_layers/jax/types_test.py | 16 +- sequence_layers/jax/typing.py | 2 +- 8 files changed, 905 insertions(+), 884 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 77991dd..d9a6672 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,4 +116,7 @@ exclude = [ ] [tool.pyrefly] -errors = { missing-override-decorator = "error" } \ No newline at end of file +# Pyrefly fails to properly support config: Config without defaults in Flax Modules +# (used in JAX), incorrectly treating them as dataclasses and complaining about +# field ordering. This effectively only impacts JAX files. +errors = { missing-override-decorator = "error", bad-class-definition = "ignore" } \ No newline at end of file diff --git a/sequence_layers/jax/backend.py b/sequence_layers/jax/backend.py index 320f495..b138b44 100644 --- a/sequence_layers/jax/backend.py +++ b/sequence_layers/jax/backend.py @@ -2,6 +2,7 @@ from typing import override +import jax.nn as jnn import jax.numpy as jnp from sequence_layers.specs import backend as spec @@ -23,8 +24,60 @@ def array(self, a, dtype=None) -> types_spec.Array: def zeros(self, shape, dtype=None) -> types_spec.Array: return jnp.zeros(shape, dtype=dtype) + @override def concatenate(self, arrays, axis=0) -> types_spec.Array: return jnp.concatenate(arrays, axis=axis) + @override + def abs(self, x) -> types_spec.Array: + return jnp.abs(x) + + @override + def exp(self, x) -> types_spec.Array: + return jnp.exp(x) + + @override + def log(self, x) -> types_spec.Array: + return jnp.log(x) + xp: spec.xp = BackendWrapper() + + +class NNWrapper(spec.nn): + """Wrapper around JAX activations to match backend protocol.""" + + @override + def relu(self, x: types_spec.Array) -> types_spec.Array: + return jnn.relu(x) + + @override + def sigmoid(self, x: types_spec.Array) -> types_spec.Array: + return jnn.sigmoid(x) + + @override + def tanh(self, x: types_spec.Array) -> types_spec.Array: + return jnn.tanh(x) + + @override + def swish(self, x: types_spec.Array) -> types_spec.Array: + return jnn.swish(x) + + @override + def gelu(self, x: types_spec.Array) -> types_spec.Array: + return jnn.gelu(x) + + @override + def elu(self, x: types_spec.Array) -> types_spec.Array: + return jnn.elu(x) + + @override + def softplus(self, x: types_spec.Array) -> types_spec.Array: + return jnn.softplus(x) + + @override + def softmax(self, x: types_spec.Array, axis: int = -1) -> types_spec.Array: + return jnn.softmax(x, axis=axis) + + +nn: spec.nn = NNWrapper() diff --git a/sequence_layers/jax/simple.py b/sequence_layers/jax/simple.py index 8d9340a..3efb6f2 100644 --- a/sequence_layers/jax/simple.py +++ b/sequence_layers/jax/simple.py @@ -20,7 +20,8 @@ import functools import math import typing -from typing import Any, Callable, Sequence as TypingSequence +from typing import Any, Callable, override +from typing import Sequence as TypingSequence from absl import logging import einops @@ -29,10 +30,15 @@ import jax.ad_checkpoint import jax.numpy as jnp import numpy as np +from typing_extensions import override + from sequence_layers.jax import sharding as sharding_lib from sequence_layers.jax import types +from sequence_layers.jax import typing as jt from sequence_layers.jax import utils -from typing_extensions import override +from sequence_layers.jax.types import MaskT +from sequence_layers.jax.types import ValuesT +from sequence_layers.specs import simple as spec # pylint: disable=logging-fstring-interpolation @@ -101,15 +107,13 @@ def _to_tuple(x: complex | list[Any]) -> complex | tuple[Any, ...]: - """Replaces lists in a pytree of complex with tuples.""" if isinstance(x, list): - return tuple(_to_tuple(i) for i in x) - else: - return x + return tuple(_to_tuple(item) for item in x) + return x @dataclasses.dataclass(frozen=True) -class HashableArray: +class HashableArray(spec.HashableArray): """Hashable multidimensional array of tuples.""" data: complex | tuple[Any, ...] @@ -120,6 +124,7 @@ def from_array(cls, x: np.ndarray) -> 'HashableArray': x = np.asarray(x) return HashableArray(_to_tuple(x.tolist()), x.dtype) + @override def to_array(self) -> np.ndarray: return np.asarray(self.data, dtype=self.dtype) @@ -150,6 +155,7 @@ def _validate( f' with the input channel shape ({input_shape=}).' ) + @override @nn.nowrap def get_output_shape( self, @@ -163,11 +169,13 @@ def get_output_shape( return jnp.broadcast_shapes(input_shape, parameter.shape) -class Scale(StatelessPointwiseBroadcasting): +class Scale( + StatelessPointwiseBroadcasting, spec.Scale[types.Sequence, types.ShapeDType] +): """Scales the input by a provided constant or array.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Scale.Config): """Config for Scale.""" # The value to scale the input by. May be a numpy array, but must be @@ -178,18 +186,26 @@ class Config(types.SequenceLayerConfig): name: str | None = None def __post_init__(self): - object.__setattr__(self, 'scale', HashableArray.from_array(self.scale)) + object.__setattr__( + self, + 'scale', + HashableArray.from_array(typing.cast(typing.Any, self.scale)), + ) + @override def make(self) -> 'Scale': - return Scale(self, name=self.name) + return Scale(config=self, name=self.name) config: Config @property + @override def _broadcast_parameter(self) -> np.ndarray: return _to_array(self.config.scale) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -227,9 +243,11 @@ def __post_init__(self): self, 'shape', [] if self.shape is None else self.shape ) + @override def make(self) -> 'Affine': - return Affine(self, name=self.name) + return Affine(config=self, name=self.name) + @override def setup(self): cfg = self.config if cfg.use_scale: @@ -249,6 +267,7 @@ def setup(self): config: Config + @override @nn.nowrap def get_output_shape( self, @@ -257,6 +276,7 @@ def get_output_shape( constants: types.Constants | None = None, ) -> types.Shape: del constants + assert self.config.shape is not None # Check that the parameters do not have batch or time dimension. if len(input_shape) < len(self.config.shape): @@ -268,7 +288,9 @@ def get_output_shape( # This function throws a value error if the shapes are not broadcastable. return jnp.broadcast_shapes(input_shape, self.config.shape) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -288,11 +310,13 @@ def layer( return x -class Add(StatelessPointwiseBroadcasting): +class Add( + StatelessPointwiseBroadcasting, spec.Add[types.Sequence, types.ShapeDType] +): """Adds the provided constant or array to the input.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Add.Config): """Config for Add.""" # The value to add to the input. May be a numpy array, but must be @@ -303,18 +327,26 @@ class Config(types.SequenceLayerConfig): name: str | None = None def __post_init__(self): - object.__setattr__(self, 'shift', HashableArray.from_array(self.shift)) + object.__setattr__( + self, + 'shift', + HashableArray.from_array(typing.cast(typing.Any, self.shift)), + ) + @override def make(self) -> 'Add': - return Add(self, name=self.name) + return Add(config=self, name=self.name) config: Config @property + @override def _broadcast_parameter(self) -> np.ndarray: return _to_array(self.config.shift) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -345,19 +377,25 @@ class Config(types.SequenceLayerConfig): def __post_init__(self): object.__setattr__( - self, 'maximum', HashableArray.from_array(self.maximum) + self, + 'maximum', + HashableArray.from_array(typing.cast(typing.Any, self.maximum)), ) + @override def make(self) -> 'Maximum': - return Maximum(self, name=self.name) + return Maximum(config=self, name=self.name) config: Config @property + @override def _broadcast_parameter(self) -> np.ndarray: return _to_array(self.config.maximum) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -388,19 +426,25 @@ class Config(types.SequenceLayerConfig): def __post_init__(self): object.__setattr__( - self, 'divisor', HashableArray.from_array(self.divisor) + self, + 'divisor', + HashableArray.from_array(typing.cast(typing.Any, self.divisor)), ) + @override def make(self) -> 'Mod': - return Mod(self, name=self.name) + return Mod(config=self, name=self.name) config: Config @property + @override def _broadcast_parameter(self) -> np.ndarray: return _to_array(self.config.divisor) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -433,19 +477,25 @@ class Config(types.SequenceLayerConfig): def __post_init__(self): object.__setattr__( - self, 'minimum', HashableArray.from_array(self.minimum) + self, + 'minimum', + HashableArray.from_array(typing.cast(typing.Any, self.minimum)), ) + @override def make(self) -> 'Minimum': - return Minimum(self, name=self.name) + return Minimum(config=self, name=self.name) config: Config @property + @override def _broadcast_parameter(self) -> np.ndarray: return _to_array(self.config.minimum) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -483,6 +533,7 @@ def __post_init__(self): else: object.__setattr__(self, 'axis', tuple(self.axis)) + @override def make(self) -> '_ReduceChannels': raise NotImplementedError() @@ -494,6 +545,7 @@ def _reduce_fn(self) -> Callable[..., jax.Array]: ... @property + @override def supports_step(self) -> bool: return True @@ -503,6 +555,7 @@ def _validate_axis(self, input_shape: types.ShapeLike) -> tuple[int, ...]: rank = len(input_shape) + 2 axis = self.config.axis if axis is not None: + # pyrefly: ignore[not-iterable] axis = [a + rank if a < 0 else a for a in axis] else: axis = list(range(2, rank)) @@ -514,6 +567,7 @@ def _validate_axis(self, input_shape: types.ShapeLike) -> tuple[int, ...]: ) return tuple(axis) + @override @nn.nowrap def get_output_shape( self, @@ -528,7 +582,9 @@ def get_output_shape( else: return tuple(d for i, d in enumerate(input_shape) if i + 2 not in axis) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -550,10 +606,12 @@ class Mean(_ReduceChannels): class Config(_ReduceChannels.Config): """Config for Mean.""" + @override def make(self) -> 'Mean': - return Mean(self, name=self.name) + return Mean(config=self, name=self.name) @property + @override def _reduce_fn(self) -> Callable[..., jax.Array]: return jnp.mean @@ -565,10 +623,12 @@ class Min(_ReduceChannels): class Config(_ReduceChannels.Config): """Config for Min.""" + @override def make(self) -> 'Min': - return Min(self, name=self.name) + return Min(config=self, name=self.name) @property + @override def _reduce_fn(self) -> Callable[..., jax.Array]: return jnp.min @@ -580,10 +640,12 @@ class Max(_ReduceChannels): class Config(_ReduceChannels.Config): """Config for Max.""" + @override def make(self) -> 'Max': - return Max(self, name=self.name) + return Max(config=self, name=self.name) @property + @override def _reduce_fn(self) -> Callable[..., jax.Array]: return jnp.max @@ -595,38 +657,50 @@ class Sum(_ReduceChannels): class Config(_ReduceChannels.Config): """Config for Sum.""" + @override def make(self) -> 'Sum': - return Sum(self, name=self.name) + return Sum(config=self, name=self.name) @property + @override def _reduce_fn(self) -> Callable[..., jax.Array]: return jnp.sum -class Abs(types.StatelessPointwiseFunctor): +class Abs( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Abs[types.Sequence, types.ShapeDType], +): """Absolute value layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Abs.Config): name: str | None = None + @override def make(self) -> 'Abs': - return Abs(self, name=self.name) + return Abs(config=self, name=self.name) config: Config @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( + def fn[ValuesT: jt.ArrayT, MaskT: jt.ArrayT]( self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + + # pyrefly: ignore[bad-argument-type] return jnp.abs(values), mask + @override @nn.nowrap def get_output_dtype( self, @@ -644,31 +718,43 @@ def get_output_dtype( return input_dtype -class Cast(types.StatelessPointwiseFunctor): +class Cast( + types.StatelessPointwiseFunctor, + spec.Cast[types.Sequence, types.ShapeDType], +): """Cast input values to the specified type.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Cast.Config): dtype: types.DType name: str | None = None + @override def make(self) -> 'Cast': - return Cast(self, name=self.name) + return Cast(config=self, name=self.name) config: Config @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] + # pyrefly: ignore[missing-attribute] return values.astype(self.config.dtype), mask + @override @nn.nowrap def get_output_dtype( self, @@ -679,21 +765,30 @@ def get_output_dtype( return self.config.dtype -class GatedUnit(types.PreservesType, types.Stateless): +class GatedUnit( + types.PreservesType, + types.Stateless, + spec.GatedUnit[types.Sequence, types.ShapeDType], +): """Computes a generalized Gated Unit, reducing the input channels by 2x.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): - feature_activation: Callable[[types.ValuesT], types.ValuesT] | None - gate_activation: Callable[[types.ValuesT], types.ValuesT] | None + class Config(spec.GatedUnit.Config): + feature_activation: Callable[[types.ArrayLike], types.ArrayLike] | None = ( + None + ) + gate_activation: Callable[[types.ArrayLike], types.ArrayLike] | None = None name: str | None = None + @override def make(self) -> 'GatedUnit': - return GatedUnit(self, name=self.name) + return GatedUnit(config=self, name=self.name) config: Config + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -706,9 +801,11 @@ def layer( feature = self.config.feature_activation(feature) if self.config.gate_activation: gate = self.config.gate_activation(gate) + # pyrefly: ignore[unsupported-operation] values = feature * gate return types.Sequence(values, x.mask) + @override @nn.nowrap def get_output_shape( self, @@ -725,29 +822,53 @@ def get_output_shape( return tuple(input_shape[:-1]) + (channels // 2,) -class GatedLinearUnit(GatedUnit): +class GatedLinearUnit( + GatedUnit, spec.GatedLinearUnit[types.Sequence, types.ShapeDType] +): """Computes a Gated Linear Unit, reducing the input channels by 2x.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(GatedUnit.Config, spec.GatedLinearUnit.Config): name: str | None = None + @override def make(self) -> 'GatedLinearUnit': return GatedLinearUnit( - GatedUnit.Config(None, jax.nn.sigmoid, name=self.name), name=self.name + config=GatedUnit.Config( + None, + typing.cast( + typing.Callable[[types.ArrayLike], types.ArrayLike], + jax.nn.sigmoid, + ), + name=self.name, + ), + name=self.name, ) -class GatedTanhUnit(GatedUnit): +class GatedTanhUnit( + GatedUnit, spec.GatedTanhUnit[types.Sequence, types.ShapeDType] +): """Computes a Gated Tanh Unit, reducing the input channels by 2x.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(GatedUnit.Config, spec.GatedTanhUnit.Config): name: str | None = None + @override def make(self) -> 'GatedTanhUnit': return GatedTanhUnit( - GatedUnit.Config(jax.nn.tanh, jax.nn.sigmoid, name=self.name), + config=GatedUnit.Config( + typing.cast( + typing.Callable[[types.ArrayLike], types.ArrayLike], + jax.nn.tanh, + ), + typing.cast( + typing.Callable[[types.ArrayLike], types.ArrayLike], + jax.nn.sigmoid, + ), + name=self.name, + ), name=self.name, ) @@ -760,13 +881,16 @@ class Config(types.SequenceLayerConfig): clip_value: float name: str | None = None + @override def make(self) -> 'GradientClipping': assert self.clip_value > 0 - return GradientClipping(self, name=self.name) + return GradientClipping(config=self, name=self.name) config: Config + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -792,17 +916,24 @@ def _custom_gradient(input_gradients): return x.apply_values_masked(_clip_gradient) -class Identity(types.PreservesType, types.StatelessPointwise): +class Identity( + types.PreservesType, + types.StatelessPointwise, + spec.Identity[types.Sequence, types.ShapeDType], +): """Identity pass-through of the input.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Identity.Config): name: str | None = None + @override def make(self) -> 'Identity': return Identity(name=self.name) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -823,12 +954,15 @@ class Config(types.SequenceLayerConfig): mask_sharding: types.Sharding | None = None name: str | None = None + @override def make(self) -> 'ApplySharding': - return ApplySharding(self, name=self.name) + return ApplySharding(config=self, name=self.name) config: Config + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -861,12 +995,15 @@ class Config(types.SequenceLayerConfig): apply_to_mask: bool = False name: str | None = None + @override def make(self) -> 'OptimizationBarrier': - return OptimizationBarrier(self, name=self.name) + return OptimizationBarrier(config=self, name=self.name) config: Config + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -885,7 +1022,10 @@ def shard_values_mask(values, mask): return x.apply_masked(shard_values_mask) -class Lambda(types.Stateless): +class Lambda( + types.Stateless, + spec.Lambda[types.Sequence, types.ShapeDType], +): """A SequenceLayer that wraps a Python lambda function. The wrapped lambda is assumed to be stateless. The receptive field of the @@ -894,7 +1034,7 @@ class Lambda(types.Stateless): """ @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Lambda.Config): """Configuration for a Lambda layer.""" # If sequence_input is True, a callable that takes an sl.Sequence and @@ -917,12 +1057,14 @@ class Config(types.SequenceLayerConfig): # An optional name for the layer. name: str | None = None + @override def make(self) -> 'Lambda': - return Lambda(self, name=self.name) + return Lambda(config=self, name=self.name) config: Config @property + @override def supports_step(self) -> bool: return True @@ -940,14 +1082,17 @@ def _validate_input_spec(self, input_spec: types.ShapeDType) -> None: # f' input spec {expected_input_spec=}' # ) + @override def get_output_spec( self, - input_spec: types.ChannelSpec, + input_spec: types.ShapeDType, *, constants: types.Constants | None = None, - ) -> types.ChannelSpec: + ) -> types.ShapeDType: self._validate_input_spec(input_spec) if self.config.sequence_input: + # pyrefly: ignore[bad-assignment] + # pyrefly: ignore[bad-specialization] input_spec = types.Sequence( types.ShapeDType( (1, 1) + tuple(input_spec.shape), @@ -963,6 +1108,7 @@ def get_output_spec( output_spec = jax.eval_shape(self.config.fn, input_spec) return jax.ShapeDtypeStruct(output_spec.shape[2:], output_spec.dtype) + @override @nn.nowrap def get_output_dtype( self, @@ -981,6 +1127,7 @@ def get_output_dtype( ) ).dtype + @override @nn.nowrap def get_output_shape( self, @@ -999,7 +1146,9 @@ def get_output_shape( ) ).shape + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -1029,36 +1178,48 @@ def layer( f' {values.shape=}' ) if self.config.mask_required: + # pyrefly: ignore[bad-specialization] y = types.Sequence(values, x.mask) else: + # pyrefly: ignore[bad-specialization] y = type(x)(values, x.mask) return y -class CheckpointName(types.PreservesType, types.StatelessPointwiseFunctor): +class CheckpointName( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.CheckpointName[types.Sequence, types.ShapeDType], +): """Applies a checkpoint name to the sequence values.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.CheckpointName.Config): checkpoint_name: str name: str | None = None + @override def make(self) -> 'CheckpointName': - return CheckpointName(self, name=self.name) + return CheckpointName(config=self, name=self.name) config: Config @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: values = jax.ad_checkpoint.checkpoint_name( values, self.config.checkpoint_name ) @@ -1083,21 +1244,27 @@ class Config(types.SequenceLayerConfig): param_dtype: types.DType = jnp.float32 name: str | None = None + @override def make(self) -> 'Snake': - return Snake(self, name=self.name) + return Snake(config=self, name=self.name) config: Config @property + @override def mask_required(self): return False + @override @nn.compact - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: channel_shape = values.shape[2:] alpha_log = self.param( 'alpha_log', @@ -1105,6 +1272,7 @@ def fn( channel_shape, self.config.param_dtype, ) + # pyrefly: ignore[bad-argument-type] alpha = jnp.exp(alpha_log)[jnp.newaxis, jnp.newaxis, ...] if self.config.separate_beta: beta_log = self.param( @@ -1113,85 +1281,123 @@ def fn( channel_shape, self.config.param_dtype, ) + # pyrefly: ignore[bad-argument-type] beta = jnp.exp(beta_log)[jnp.newaxis, jnp.newaxis, ...] else: beta = alpha + # pyrefly: ignore[bad-argument-type] + # pyrefly: ignore[unsupported-operation] values += jnp.square(jnp.sin(values * alpha)) / (beta + 1e-12) return values, mask -class Tanh(types.PreservesType, types.StatelessPointwiseFunctor): +class Tanh( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Tanh[types.Sequence, types.ShapeDType], +): """A tanh layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Tanh.Config): name: str | None = None + @override def make(self) -> 'Tanh': - return Tanh(self, name=self.name) + return Tanh(config=self, name=self.name) config: Config @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jax.nn.tanh(values), mask -class Relu(types.PreservesType, types.StatelessPointwiseFunctor): +class Relu( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Relu[types.Sequence, types.ShapeDType], +): """A Relu layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Relu.Config): name: str | None = None + @override def make(self) -> 'Relu': - return Relu(name=self.name) + return Relu(config=self, name=self.name) + + config: Config @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jax.nn.relu(values), mask -class LeakyRelu(types.PreservesType, types.StatelessPointwiseFunctor): +class LeakyRelu( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.LeakyRelu[types.Sequence, types.ShapeDType], +): """A Leaky Relu layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.LeakyRelu.Config): negative_slope: complex = 0.01 name: str | None = None + @override def make(self) -> 'LeakyRelu': - return LeakyRelu(self, name=self.name) + return LeakyRelu(config=self, name=self.name) config: Config @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jax.nn.leaky_relu(values, self.config.negative_slope), mask @@ -1204,11 +1410,13 @@ class Config(types.SequenceLayerConfig): param_dtype: types.DType = jnp.float32 name: str | None = None + @override def make(self) -> 'PRelu': - return PRelu(self, name=self.name) + return PRelu(config=self, name=self.name) config: Config + @override def setup(self): self.negative_slope = self.param( 'negative_slope', @@ -1218,87 +1426,131 @@ def setup(self): ) @property + @override def mask_required(self) -> bool: return False + @override @nn.nowrap - def fn( + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: return ( + # pyrefly: ignore[no-matching-overload] jnp.where( + # pyrefly: ignore[unsupported-operation] values >= 0, values, + # pyrefly: ignore[unsupported-operation] + # pyrefly: ignore[bad-argument-type] self.negative_slope.astype(values.dtype) * values, ), mask, ) -class Elu(types.PreservesType, types.StatelessPointwiseFunctor): +class Elu( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Elu[types.Sequence, types.ShapeDType], +): """An elu activation layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Elu.Config): alpha: complex = 1.0 name: str | None = None + @override def make(self) -> 'Elu': - return Elu(self, name=self.name) + return Elu(config=self, name=self.name) config: Config + @property + @override + def mask_required(self): + return False + + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jax.nn.elu(values, self.config.alpha), mask -class Exp(types.PreservesType, types.StatelessPointwiseFunctor): +class Exp( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Exp[types.Sequence, types.ShapeDType], +): """An exp layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Exp.Config): name: str | None = None + @override def make(self) -> 'Exp': - return Exp(self, name=self.name) + return Exp(config=self, name=self.name) config: Config + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jnp.exp(values), mask -class Log(types.PreservesType, types.StatelessPointwiseFunctor): +class Log( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Log[types.Sequence, types.ShapeDType], +): """A log layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Log.Config): name: str | None = None + @override def make(self) -> 'Log': - return Log(self, name=self.name) + return Log(config=self, name=self.name) config: Config + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jnp.log(values), mask @@ -1310,140 +1562,203 @@ class Config(types.SequenceLayerConfig): power: float = 1.0 name: str | None = None + @override def make(self) -> 'Power': - return Power(self, name=self.name) + return Power(config=self, name=self.name) config: Config @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jnp.power(values, self.config.power), mask -class Sigmoid(types.PreservesType, types.StatelessPointwiseFunctor): +class Sigmoid( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Sigmoid[types.Sequence, types.ShapeDType], +): """A sigmoid layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Sigmoid.Config): name: str | None = None + @override def make(self) -> 'Sigmoid': - return Sigmoid(self, name=self.name) + return Sigmoid(config=self, name=self.name) config: Config + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jax.nn.sigmoid(values), mask -class Softplus(types.PreservesType, types.StatelessPointwiseFunctor): +class Softplus( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Softplus[types.Sequence, types.ShapeDType], +): """A softplus layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Softplus.Config): name: str | None = None + @override def make(self) -> 'Softplus': - return Softplus(self, name=self.name) + return Softplus(config=self, name=self.name) config: Config + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jax.nn.softplus(values), mask -class Softmax(types.PreservesType, types.StatelessPointwiseFunctor): +class Softmax( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Softmax[types.Sequence, types.ShapeDType], +): """A softmax layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Softmax.Config): axis: int = -1 name: str | None = None + @override def make(self) -> 'Softmax': - return Softmax(self, name=self.name) + return Softmax(config=self, name=self.name) config: Config + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: axis = self.config.axis if (axis if axis >= 0 else values.ndim + axis) < 2: raise ValueError( 'The softmax cannot be applied on the batch or time dimension (got' f' {axis=} for shape={values.shape})' ) + # pyrefly: ignore[bad-argument-type] return jax.nn.softmax(values, axis=axis), mask -class Swish(types.PreservesType, types.StatelessPointwiseFunctor): +class Swish( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Swish[types.Sequence, types.ShapeDType], +): """A Swish layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Swish.Config): name: str | None = None + @override def make(self) -> 'Swish': return Swish(name=self.name) @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + # pyrefly: ignore[missing-override-decorator] + # pyrefly: ignore[bad-argument-type] + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jax.nn.swish(values), mask -class Gelu(types.PreservesType, types.StatelessPointwiseFunctor): +class Gelu( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Gelu[types.Sequence, types.ShapeDType], +): """A Gaussian Error Linear Unit (GELU) layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Gelu.Config): approximate: bool = True name: str | None = None + @override def make(self) -> 'Gelu': - return Gelu(self, name=self.name) + return Gelu(config=self, name=self.name) config: Config @property + @override def mask_required(self): return False + @override @nn.nowrap - def fn( - self, - values: types.ValuesT, - mask: types.MaskT, - ) -> tuple[types.ValuesT, types.MaskT]: + # pyrefly: ignore[missing-override-decorator] + # pyrefly: ignore[bad-argument-type] + def fn[ + ValuesT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + MaskT: (jax.Array, np.ndarray, jax.ShapeDtypeStruct), + ]( + self, + values: ValuesT, + mask: MaskT, + ) -> tuple[ValuesT, MaskT]: + # pyrefly: ignore[bad-argument-type] return jax.nn.gelu(values, approximate=self.config.approximate), mask @@ -1471,13 +1786,14 @@ def __post_init__(self): # Use hashable types for sequences. object.__setattr__(self, 'slices', tuple(self.slices)) - def as_slices(self) -> tuple[slice | int | None]: + def as_slices(self) -> tuple[slice | int | None, ...]: return tuple( slice(*s) if isinstance(s, tuple) else s for s in self.slices ) + @override def make(self) -> 'Slice': - return Slice(self, name=self.name) + return Slice(config=self, name=self.name) config: Config @@ -1490,6 +1806,7 @@ def _validate_slice_for_input_shape(self, input_shape: types.ShapeLike): % (input_shape, self.config.slices) ) + @override def get_output_shape( self, input_shape: types.ShapeLike, @@ -1500,10 +1817,6 @@ def get_output_shape( output_dims = [] input_index = 0 - # Compute the output shape: - # - int: Remove the current input dimension. - # - slice: Compute the output dimension size using slice.indices. - # - None (tf.newaxis): Add a dimension. for slice_i in self.config.slices: if isinstance(slice_i, tuple): slice_i = slice(*slice_i) @@ -1525,7 +1838,9 @@ def get_output_shape( ) return tuple(output_dims) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -1533,7 +1848,6 @@ def layer( training: bool, constants: types.Constants | None = None, ) -> types.Sequence: - # Slice the batch and time dimensions with [:, :]. full_slice = ( slice(None, None, None), slice(None, None, None), @@ -1543,7 +1857,11 @@ def layer( return x.apply_values_masked(lambda v: v.__getitem__(full_slice)) -class Flatten(types.PreservesType, types.Stateless): +class Flatten( + types.PreservesType, + types.Stateless, + spec.Flatten[types.Sequence, types.ShapeDType], +): """Flattens the channel dimensions of the input sequence. An input sequence with shape [batch_size, time, ...] is reshaped to @@ -1554,9 +1872,11 @@ class Flatten(types.PreservesType, types.Stateless): class Config(types.SequenceLayerConfig): name: str | None = None + @override def make(self) -> 'Flatten': return Flatten(name=self.name) + @override @nn.nowrap def get_output_shape( self, @@ -1565,9 +1885,11 @@ def get_output_shape( constants: types.Constants | None = None, ) -> types.Shape: del constants - return (np.prod(input_shape),) + return (int(np.prod(input_shape)),) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -1581,17 +1903,18 @@ def layer( return x.apply_values_masked(jnp.reshape, [batch_size, time, num_elements]) -class OneHot(types.Stateless): +class OneHot(types.Stateless, spec.OneHot[types.Sequence, types.ShapeDType]): """Computes one-hot vector of the input.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.OneHot.Config): depth: int compute_dtype: types.DType = jnp.float32 name: str | None = None + @override def make(self) -> 'OneHot': - return OneHot(self, name=self.name) + return OneHot(config=self, name=self.name) config: Config @@ -1603,6 +1926,7 @@ def _validate(self, dtype: types.DType): f' {dtype}' ) + @override @nn.nowrap def get_output_shape( self, @@ -1612,6 +1936,7 @@ def get_output_shape( ) -> types.Shape: return tuple(input_shape) + (self.config.depth,) + @override @nn.nowrap def get_output_dtype( self, @@ -1622,7 +1947,9 @@ def get_output_dtype( self._validate(input_dtype) return self.config.compute_dtype + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -1641,11 +1968,13 @@ def layer( ) -class Embedding(types.Stateless): +class Embedding( + types.Stateless, spec.Embedding[types.Sequence, types.ShapeDType] +): """Computes embeddings of integer input codes.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Embedding.Config): """Config for Embedding.""" # Dimensionality of the embedded values. @@ -1665,11 +1994,13 @@ class Config(types.SequenceLayerConfig): name: str | None = None embedding_param_name: str = 'embedding' + @override def make(self) -> 'Embedding': - return Embedding(self, name=self.name) + return Embedding(config=self, name=self.name) config: Config + @override def setup(self): self.embedding = self.param( self.config.embedding_param_name, @@ -1688,6 +2019,7 @@ def _validate(self, dtype: types.DType): f' {dtype}' ) + @override @nn.nowrap def get_output_shape( self, @@ -1695,8 +2027,10 @@ def get_output_shape( *, constants: types.Constants | None = None, ) -> types.Shape: + del constants return tuple(input_shape) + (self.config.dimension,) + @override @nn.nowrap def get_output_dtype( self, @@ -1709,7 +2043,9 @@ def get_output_dtype( return self.config.param_dtype return self.config.compute_dtype + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -1791,7 +2127,7 @@ class Config(types.SequenceLayerConfig): @override def make(self) -> 'EmbeddingTranspose': - return EmbeddingTranspose(self, name=self.name) + return EmbeddingTranspose(config=self, name=self.name) config: Config @@ -1809,6 +2145,7 @@ def get_output_dtype( *, constants: types.Constants | None = None, ) -> types.DType: + assert self.embedding.config is not None return utils.get_promoted_dtype( input_dtype, self.config.param_dtype or self.embedding.config.param_dtype, @@ -1823,6 +2160,8 @@ def get_output_shape( *, constants: types.Constants | None = None, ) -> types.Shape: + del constants + assert self.config.embedding.config is not None if ( not input_shape or input_shape[-1] != self.config.embedding.config.dimension @@ -1836,14 +2175,16 @@ def get_output_shape( @override @types.check_layer @nn.compact + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, + *, training: bool, constants: types.Constants | None = None, ) -> types.Sequence: del training, constants - + assert self.embedding.config is not None if self.config.use_bias: bias_init = utils.shard_initializer( self.config.bias_init, self.config.bias_sharding @@ -1868,11 +2209,15 @@ def layer( return ret -class ExpandDims(types.PreservesType, types.Stateless): +class ExpandDims( + types.PreservesType, + types.Stateless, + spec.ExpandDims[types.Sequence, types.ShapeDType], +): """Applies jnp.expand_dims to the channels dimension of the input.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.ExpandDims.Config): """Configuration for ExpandDims.""" # The axis or axes in the channel shape to expand dims on. @@ -1886,8 +2231,9 @@ def __post_init__(self): if not isinstance(self.axis, int): object.__setattr__(self, 'axis', tuple(self.axis)) + @override def make(self) -> 'ExpandDims': - return ExpandDims(self, name=self.name) + return ExpandDims(config=self, name=self.name) config: Config @@ -1912,6 +2258,7 @@ def _normalize_and_validate_axes( return dims @nn.nowrap + @override def get_output_shape( self, input_shape: types.ShapeLike, @@ -1926,6 +2273,8 @@ def get_output_shape( return tuple(output_shape) @types.check_layer + @override + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -1938,11 +2287,15 @@ def layer( return x.apply_values_masked(jnp.expand_dims, dims) -class Reshape(types.PreservesType, types.Stateless): +class Reshape( + types.PreservesType, + types.Stateless, + spec.Reshape[types.Sequence, types.ShapeDType], +): """Reshapes the channels dimension of the input.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Reshape.Config): """Configuration for Reshape.""" # The new shape of the channels dimension. Can't contain -1, and must have @@ -1955,8 +2308,9 @@ def __post_init__(self): # Use hashable types for sequences. object.__setattr__(self, 'output_shape', tuple(self.output_shape)) + @override def make(self) -> 'Reshape': - return Reshape(self, name=self.name) + return Reshape(config=self, name=self.name) config: Config @@ -1970,6 +2324,7 @@ def _validate_output_shape(self, input_shape: types.ShapeLike) -> None: ) @nn.nowrap + @override def get_output_shape( self, input_shape: types.ShapeLike, @@ -1981,6 +2336,8 @@ def get_output_shape( return tuple(self.config.output_shape) @types.check_layer + @override + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2028,8 +2385,9 @@ def __post_init__(self): # Use hashable types for sequences. object.__setattr__(self, 'output_shape', tuple(self.output_shape)) + @override def make(self) -> 'GlobalReshape': - return GlobalReshape(self, name=self.name) + return GlobalReshape(config=self, name=self.name) config: Config @@ -2043,6 +2401,7 @@ def _validate_reshape(self, input_shape: types.ShapeLike) -> None: ) @nn.nowrap + @override def get_output_shape( self, input_shape: types.ShapeLike, @@ -2053,14 +2412,18 @@ def get_output_shape( return tuple(self.config.output_shape[1:]) @property + @override def supports_step(self) -> bool: return False @property + @override def receptive_field_per_step(self) -> dict[int, types.ReceptiveField]: return {0: (-np.inf, np.inf)} @types.check_layer + @override + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2088,11 +2451,15 @@ def layer( return types.Sequence(out, mask=mask) -class Transpose(types.PreservesType, types.Stateless): +class Transpose( + types.PreservesType, + types.Stateless, + spec.Transpose[types.Sequence, types.ShapeDType], +): """Transposes (i.e., permutes) the channels dimension of the input.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Transpose.Config): """Configuration for Transpose. The usage is the same as that of jax.numpy.transpose. @@ -2118,12 +2485,13 @@ def __post_init__(self): if self.axes is not None: object.__setattr__(self, 'axes', tuple(self.axes)) + @override def make(self) -> 'Transpose': if self.axes is not None and (0 in self.axes or 1 in self.axes): raise ValueError("Can't transpose batch or time dimension.") - return Transpose(self, name=self.name) + return Transpose(config=self, name=self.name) config: Config @@ -2143,6 +2511,7 @@ def _validate_axes(self, input_shape: types.ShapeLike) -> tuple[int, ...]: return tuple(axes) + @override @nn.nowrap def get_output_shape( self, @@ -2154,7 +2523,9 @@ def get_output_shape( axes = self._validate_axes(input_shape) return tuple(input_shape[a - 2] for a in axes) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2170,23 +2541,28 @@ class SwapAxes(Transpose): """Swap two channel axes.""" @dataclasses.dataclass(frozen=True) + # pyrefly: ignore[bad-override] class Config(types.SequenceLayerConfig): axis1: int axis2: int name: str | None = None + @override def make(self) -> 'SwapAxes': axes = [self.axis1, self.axis2] if 0 in axes or 1 in axes: raise ValueError("Can't swap batch or time dimension.") + # pyrefly: ignore[missing-argument] return SwapAxes( - Transpose.Config(axes=axes, name=self.name), name=self.name + typing.cast(typing.Any, Transpose.Config(axes=axes, name=self.name)), + name=self.name, ) @override def _validate_axes(self, input_shape: types.ShapeLike) -> tuple[int, ...]: + assert self.config.axes is not None ndim = 2 + len(input_shape) # ndim including batch and time. axes = [a if a >= 0 else ndim + a for a in self.config.axes] if 0 in axes or 1 in axes: @@ -2205,7 +2581,7 @@ class MoveAxis(Transpose): """Moves one or several channel axes to new locations.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(types.SequenceLayerConfig): # pyrefly: ignore[bad-override] """Config of MoveAxis layer.""" source: int | TypingSequence[int] @@ -2217,24 +2593,22 @@ def __post_init__(self): object.__setattr__(self, 'source', to_tuple(self.source)) object.__setattr__(self, 'destination', to_tuple(self.destination)) + @override def make(self) -> 'MoveAxis': - - if ( - 0 in self.source - or 1 in self.source - or 0 in self.destination - or 1 in self.destination - ): + source = typing.cast(TypingSequence[int], self.source) + destination = typing.cast(TypingSequence[int], self.destination) + if 0 in source or 1 in source or 0 in destination or 1 in destination: raise ValueError("Can't move batch or time dimension.") - if len(self.source) != len(self.destination): + if len(source) != len(destination): raise ValueError( - f'Inconsistent number of elements: {len(self.source)} vs' - f' {len(self.destination)}' + f'Inconsistent number of elements: {len(source)} vs' + f' {len(destination)}' ) - return MoveAxis(self, name=self.name) + return MoveAxis(config=self, name=self.name) + # pyrefly: ignore[bad-override] config: Config @override @@ -2262,12 +2636,15 @@ class Emit(types.PreservesType, types.PreservesShape, types.StatelessEmitting): class Config(types.SequenceLayerConfig): name: str | None = None + @override def make(self) -> 'Emit': - return Emit(self, name=self.name) + return Emit(config=self, name=self.name) config: Config @types.check_layer_with_emits + @override + # pyrefly: ignore[missing-override-decorator] def layer_with_emits( self, x: types.Sequence, @@ -2288,12 +2665,15 @@ class Config(types.SequenceLayerConfig): emit_name: str name: str | None = None + @override def make(self) -> 'NamedEmit': - return NamedEmit(self, name=self.name) + return NamedEmit(config=self, name=self.name) config: Config @types.check_layer_with_emits + @override + # pyrefly: ignore[missing-override-decorator] def layer_with_emits( self, x: types.Sequence, @@ -2302,24 +2682,32 @@ def layer_with_emits( constants: types.Constants | None = None, ) -> tuple[types.Sequence, types.Emits]: return x, {self.config.emit_name: x} + return x, {self.config.emit_name: x} -class Dropout(types.PreservesType, types.StatelessPointwise): +class Dropout( + types.PreservesType, + types.StatelessPointwise, + spec.Dropout[types.Sequence, types.ShapeDType], +): """Computes dropout using Flax RNGs.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Dropout.Config): rate: float = 0.0 broadcast_dims: TypingSequence[int] = () rng_collection: str = 'dropout' name: str | None = None + @override def make(self) -> 'Dropout': - return Dropout(self, name=self.name) + return Dropout(config=self, name=self.name) config: Config + @override @types.check_step + # pyrefly: ignore[missing-override-decorator] def step( self, x: types.Sequence, @@ -2385,7 +2773,9 @@ def apply_dropout(self, x: jax.Array, training: bool) -> jax.Array: return x @nn.compact + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2397,34 +2787,54 @@ def layer( return x.apply_values_masked(self.apply_dropout, training=training) -class Downsample1D(types.PreservesType, types.PreservesShape, types.Stateless): +class Downsample1D( + types.PreservesType, + types.Stateless, + spec.Downsample1D[types.Sequence, types.ShapeDType], +): """A 1D downsampling layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Downsample1D.Config): """Configuration for Downsample1D.""" rate: int name: str | None = None + @override def make(self) -> 'Downsample1D': - return Downsample1D(self, name=self.name) + return Downsample1D(config=self, name=self.name) config: Config @property + @override def block_size(self) -> int: return self.config.rate @property + @override def output_ratio(self) -> fractions.Fraction: return fractions.Fraction(1, self.config.rate) @property + @override def input_latency(self) -> int: return self.config.rate - 1 + @override + @nn.nowrap + def get_output_shape( + self, + input_shape: types.ShapeLike, + *, + constants: types.Constants | None = None, + ) -> types.Shape: + return tuple(input_shape) + + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2439,30 +2849,49 @@ def layer( ) -class Upsample1D(types.PreservesType, types.PreservesShape, types.Stateless): +class Upsample1D( + types.PreservesType, + types.Stateless, + spec.Upsample1D[types.Sequence, types.ShapeDType], +): """A 1D upsampling layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Upsample1D.Config): """Configuration for Upsample1D.""" rate: int name: str | None = None + @override def make(self) -> 'Upsample1D': - return Upsample1D(self, name=self.name) + return Upsample1D(config=self, name=self.name) config: Config @property + @override def output_ratio(self) -> fractions.Fraction: return fractions.Fraction(self.config.rate) @property + @override def receptive_field_per_step(self) -> dict[int, types.ReceptiveField]: return {s: (0, 0) for s in range(self.config.rate)} + @override + @nn.nowrap + def get_output_shape( + self, + input_shape: types.ShapeLike, + *, + constants: types.Constants | None = None, + ) -> types.Shape: + return tuple(input_shape) + + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2491,26 +2920,33 @@ class Config(types.SequenceLayerConfig): def __post_init__(self): object.__setattr__(self, 'rate', utils.normalize_2tuple(self.rate)) + @override def make(self) -> 'Upsample2D': - return Upsample2D(self, name=self.name) + return Upsample2D(config=self, name=self.name) config: Config @property + @override def output_ratio(self) -> fractions.Fraction: - return fractions.Fraction(self.config.rate[0]) + rate = typing.cast(TypingSequence[int], self.config.rate) + return fractions.Fraction(rate[0]) @property + @override def receptive_field_per_step(self) -> dict[int, types.ReceptiveField]: - return {s: (0, 0) for s in range(self.config.rate[0])} + rate = typing.cast(TypingSequence[int], self.config.rate) + return {s: (0, 0) for s in range(rate[0])} @nn.nowrap + @override def get_output_shape( self, input_shape: types.ShapeLike, *, constants: types.Constants | None = None, ) -> types.Shape: + rate = typing.cast(TypingSequence[int], self.config.rate) if len(input_shape) != 2: raise ValueError( 'Upsample2D requires rank 4 input got:' @@ -2518,11 +2954,13 @@ def get_output_shape( ) return ( - input_shape[0] * self.config.rate[1], + input_shape[0] * rate[1], input_shape[1], ) @types.check_layer + @override + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2530,28 +2968,36 @@ def layer( training: bool, constants: types.Constants | None = None, ) -> types.Sequence: - values = jnp.repeat(x.values, self.config.rate[0], axis=1) - values = jnp.repeat(values, self.config.rate[1], axis=2) - mask = jnp.repeat(x.mask, self.config.rate[0], axis=1) + rate = typing.cast(TypingSequence[int], self.config.rate) + values = jnp.repeat(x.values, rate[0], axis=1) + values = jnp.repeat(values, rate[1], axis=2) + mask = jnp.repeat(x.mask, rate[0], axis=1) # Upsampling does not change the masked state, so use the type of x to # repack the upsampled values and mask. return type(x)(values, mask) -class MaskInvalid(types.PreservesType, types.StatelessPointwise): +class MaskInvalid( + types.PreservesType, + types.StatelessPointwise, + spec.MaskInvalid[types.Sequence, types.ShapeDType], +): """Masks the input sequence.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.MaskInvalid.Config): name: str | None = None + @override def make(self) -> 'MaskInvalid': - return MaskInvalid(self, name=self.name) + return MaskInvalid(config=self, name=self.name) config: Config + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2563,11 +3009,15 @@ def layer( return x.mask_invalid() -class Logging(types.PreservesType, types.StatelessPointwise): +class Logging( + types.PreservesType, + types.StatelessPointwise, + spec.Logging[types.Sequence, types.ShapeDType], +): """Layer that logs input arguments to get_initial_state, step, and layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Logging.Config): """Configuration for the Logging layer.""" prefix: str = '' @@ -2588,8 +3038,9 @@ class Config(types.SequenceLayerConfig): '\ttraining={training}\n\tconstants={constants}' ) + @override def make(self) -> 'Logging': - return Logging(self) + return Logging(config=self) config: Config @@ -2601,6 +3052,7 @@ def _register_callback(self, format_str: str, **kwargs) -> None: nonjax_kwargs = {'prefix': self.config.prefix} for k, v in list(kwargs.items()): if isinstance(v, jax.ShapeDtypeStruct) or not jax.core.valid_jaxtype(v): + # pyrefly: ignore[bad-typed-dict-key] nonjax_kwargs[k] = v del kwargs[k] # We then set up a callback for the remaining tensor values: @@ -2621,7 +3073,9 @@ def arrays_to_specs(leaf: Any) -> types.ShapeDType | str: kwargs = jax.tree.map(arrays_to_specs, kwargs) logging.info(format_str.format(prefix=self.config.prefix, **kwargs)) + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2637,10 +3091,11 @@ def layer( ) return x + @override def get_initial_state( self, batch_size: int, - input_spec: types.ChannelSpec, + input_spec: types.ShapeDType, *, training: bool, constants: types.Constants | None = None, @@ -2657,6 +3112,8 @@ def get_initial_state( ) @types.check_step + @override + # pyrefly: ignore[missing-override-decorator] def step( self, x: types.Sequence, @@ -2682,12 +3139,15 @@ class Argmax(types.Stateless): class Config(types.SequenceLayerConfig): name: str | None = None + @override def make(self) -> 'Argmax': - return Argmax(self, name=self.name) + return Argmax(config=self, name=self.name) config: Config @types.check_layer + @override + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2697,6 +3157,7 @@ def layer( ) -> types.Sequence: return x.apply_values(jnp.argmax, axis=-1) + @override @nn.nowrap def get_output_shape( self, @@ -2706,6 +3167,7 @@ def get_output_shape( ) -> types.Shape: return tuple(input_shape[:-1]) + @override @nn.nowrap def get_output_dtype( self, @@ -2743,12 +3205,14 @@ def __post_init__(self): f'`batch` and `time` are reserved axes labels (got {self.pattern}).' ) + @override def make(self) -> 'EinopsRearrange': - return EinopsRearrange(self, name=self.name) + return EinopsRearrange(config=self, name=self.name) config: Config @property + @override def supports_step(self) -> bool: return True @@ -2759,6 +3223,8 @@ def _get_rearrange_fn(self) -> Callable[[jax.Array], jax.Array]: return functools.partial(einops.rearrange, pattern=pattern, **axes_lengths) @types.check_layer + @override + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2771,6 +3237,7 @@ def layer( return x.apply_values(rearrange_fn) @nn.nowrap + @override def get_output_shape( self, input_shape: types.ShapeLike, @@ -2779,7 +3246,9 @@ def get_output_shape( ) -> types.Shape: del constants rearrange_fn = self._get_rearrange_fn() - output = jax.eval_shape(rearrange_fn, jnp.zeros((1, 1) + input_shape)) + output = jax.eval_shape( + rearrange_fn, jnp.zeros((1, 1) + tuple(input_shape)) + ) return tuple(output.shape[2:]) @@ -2814,18 +3283,21 @@ def __post_init__(self): f'`batch` is a reserved axes labels (got {self.pattern}).' ) + @override def make(self) -> 'GlobalEinopsRearrange': - return GlobalEinopsRearrange(self, name=self.name) + return GlobalEinopsRearrange(config=self, name=self.name) config: Config @property + @override def supports_step(self) -> bool: return False @property - def receptive_field(self) -> tuple[int | None, int | None]: - return (-np.inf, np.inf) + @override + def receptive_field(self) -> types.ReceptiveField: + return typing.cast(types.ReceptiveField, (-np.inf, np.inf)) def _get_rearrange_fn(self) -> Callable[[jax.Array], jax.Array]: before, after = self.config.pattern.split('->') @@ -2834,6 +3306,8 @@ def _get_rearrange_fn(self) -> Callable[[jax.Array], jax.Array]: return functools.partial(einops.rearrange, pattern=pattern, **axes_lengths) @types.check_layer + @override + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, @@ -2863,6 +3337,7 @@ def layer( return types.Sequence(values, mask) @nn.nowrap + @override def get_output_shape( self, input_shape: types.ShapeLike, @@ -2876,12 +3351,16 @@ def get_output_shape( else: time_dim = 1 output = jax.eval_shape( - rearrange_fn, jnp.zeros((1, time_dim) + input_shape) + rearrange_fn, jnp.zeros((1, time_dim) + tuple(input_shape)) ) return tuple(output.shape[2:]) -class Squeeze(types.PreservesType, types.Stateless): +class Squeeze( + types.PreservesType, + types.Stateless, + spec.Squeeze[types.Sequence, types.ShapeDType], +): """This layer squeezes all the depth dimensions of the input. I.e. [batch_size, time, *depth_dims -> [batch_size, time] (where all the @@ -2889,13 +3368,14 @@ class Squeeze(types.PreservesType, types.Stateless): """ @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(spec.Squeeze.Config): """Config of Squeeze.""" axis: int | TypingSequence[int] | None = None name: str | None = None + @override def make(self) -> 'Squeeze': axis = self.axis @@ -2905,7 +3385,7 @@ def make(self) -> 'Squeeze': elif axis is not None and (0 in axis or 1 in axis): raise ValueError('Batch and time (axis=0 or 1) cannot be squeezed.') - return Squeeze(self, name=self.name) + return Squeeze(config=self, name=self.name) config: Config @@ -2920,6 +3400,7 @@ def _validate_axis(self, input_shape: types.ShapeLike) -> tuple[int, ...]: return tuple(axis) + @override @nn.nowrap def get_output_shape( self, @@ -2934,7 +3415,9 @@ def get_output_shape( types.ShapeDType((0, 1) + tuple(input_shape), jnp.float32), ).shape[2:] + @override @types.check_layer + # pyrefly: ignore[missing-override-decorator] def layer( self, x: types.Sequence, diff --git a/sequence_layers/jax/simple_test.py b/sequence_layers/jax/simple_test.py index c9a98c8..321a22e 100644 --- a/sequence_layers/jax/simple_test.py +++ b/sequence_layers/jax/simple_test.py @@ -30,132 +30,20 @@ import jax.experimental.mesh_utils # Required for OSS. import jax.numpy as jnp import numpy as np + from sequence_layers.jax import sharding as sharding_lib from sequence_layers.jax import simple from sequence_layers.jax import test_utils from sequence_layers.jax import types +from sequence_layers.specs import simple_behaviors as spec -class ScaleTest(test_utils.SequenceLayerTest): - - @parameterized.parameters(((2, 13, 5),), ((2, 13, 5, 9),)) - def test_basic(self, shape): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*shape) - l = simple.Scale.Config(scale=2.0, name='scale').make() - l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) - self.assertEqual(l.name, 'scale') - y = self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - y_expected = x.apply_values(lambda v: v * 2.0) - self.assertSequencesEqual(y, y_expected) - - @parameterized.parameters(((2, 13, 5),), ((2, 13, 9, 5),)) - def test_ndarray(self, shape): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*shape) - l = simple.Scale.Config( - scale=np.arange(5, dtype=np.float32), name='scale' - ).make() - l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) - self.assertEqual(l.name, 'scale') - y = self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - y_expected = x.apply_values(lambda v: v * np.arange(5, dtype=np.float32)) - self.assertSequencesEqual(y, y_expected) - - def test_broadcast(self): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(2, 3, 5, 1) - l = simple.Scale.Config(scale=np.ones((5, 9))).make() - l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.get_output_shape_for_sequence(x), (5, 9)) - - def test_too_many_dims(self): - x = test_utils.random_sequence(2, 3, 5, 1) - l = simple.Scale.Config(scale=np.ones((5, 5, 5))).make().bind({}) - with self.assertRaises(ValueError): - l.get_output_shape_for_sequence(x) - - with self.assertRaises(ValueError): - l.layer(x, training=False) - - def test_broadcast_failure(self): - x = test_utils.random_sequence(2, 3, 5, 9) - l = simple.Scale.Config(scale=np.ones((5,))).make().bind({}) - with self.assertRaises(ValueError): - l.get_output_shape_for_sequence(x) - - with self.assertRaises(ValueError): - l.layer(x, training=False) +class ScaleTest(test_utils.SequenceLayerTest, spec.ScaleTest): + pass -class AddTest(test_utils.SequenceLayerTest): - - @parameterized.parameters((((2, 13, 5)),), (((2, 13, 5, 9)),)) - def test_add(self, shape): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*shape) - l = simple.Add.Config(-2.0, name='add').make() - l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) - self.assertEqual(l.name, 'add') - y = self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - y_expected = x.apply_values(lambda v: v - 2.0).mask_invalid() - self.assertSequencesEqual(y, y_expected) - - @parameterized.parameters(((2, 13, 5),), ((2, 13, 9, 5),)) - def test_ndarray(self, shape): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*shape) - l = simple.Add.Config( - shift=np.arange(5, dtype=np.float32), name='add' - ).make() - l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) - self.assertEqual(l.name, 'add') - y = self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - y_expected = x.apply_values( - lambda v: v + np.arange(5, dtype=np.float32) - ).mask_invalid() - self.assertSequencesEqual(y, y_expected) - - def test_broadcast(self): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(2, 3, 5, 1) - l = simple.Add.Config(shift=np.ones((5, 9))).make() - l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.get_output_shape_for_sequence(x), (5, 9)) - - def test_too_many_dims(self): - x = test_utils.random_sequence(2, 3, 5, 1) - l = simple.Add.Config(shift=np.ones((5, 5, 5))).make().bind({}) - with self.assertRaises(ValueError): - l.get_output_shape_for_sequence(x) - - with self.assertRaises(ValueError): - l.layer(x, training=False) - - def test_broadcast_failure(self): - x = test_utils.random_sequence(2, 3, 5, 9) - l = simple.Add.Config(shift=np.ones((5,))).make().bind({}) - with self.assertRaises(ValueError): - l.get_output_shape_for_sequence(x) - - with self.assertRaises(ValueError): - l.layer(x, training=False) +class AddTest(test_utils.SequenceLayerTest, spec.AddTest): + pass class MinimumTest(test_utils.SequenceLayerTest): @@ -344,34 +232,31 @@ def test_broadcast_failure(self): l.layer(x, training=False) -class GatedUnitTest(test_utils.SequenceLayerTest): +class GatedUnitTest(test_utils.SequenceLayerTest, spec.GatedUnitTest): @parameterized.parameters( itertools.product( - (simple.GatedUnit.Config(None, None), # Bilinear - simple.GatedUnit.Config(None, jax.nn.swish), # SwiGLU - simple.GatedUnit.Config(None, jax.nn.gelu), # GeGLU - simple.GatedUnit.Config(lambda x: x, None), # Bilinear - simple.GatedUnit.Config(jax.nn.swish, jax.nn.tanh), - simple.GatedTanhUnit.Config(), - simple.GatedLinearUnit.Config()), - ((2, 13, 6), (2, 13, 5, 10))) - ) # pyformat: disable - def test_gated_activation(self, layer_config, shape): + ( + simple.GatedUnit.Config(None, None), # Bilinear + simple.GatedUnit.Config(None, jax.nn.swish), # SwiGLU + simple.GatedUnit.Config(None, jax.nn.gelu), # GeGLU + simple.GatedUnit.Config(lambda x: x, None), # Bilinear + simple.GatedUnit.Config(jax.nn.swish, jax.nn.tanh), + simple.GatedTanhUnit.Config(), + simple.GatedLinearUnit.Config(), + ), + ((2, 13, 6), (2, 13, 5, 10)), + ) + ) # pyformat: disable + def test_variables_empty(self, layer_config, shape): key = jax.random.PRNGKey(1234) x = test_utils.random_sequence(*shape) l = layer_config.make() l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual( - l.get_output_shape_for_sequence(x), shape[2:-1] + (shape[-1] // 2,) - ) - self.verify_contract(l, x, training=True) self.assertEmpty(l.variables) -class DropoutTest(test_utils.SequenceLayerTest): +class DropoutTest(test_utils.SequenceLayerTest, spec.DropoutTest): @parameterized.parameters( jnp.float32, jnp.bfloat16, jnp.int32, jnp.int8, jnp.bool @@ -511,28 +396,8 @@ def test_slice_wrongsize(self): l.layer(x, training=False) -class FlattenTest(test_utils.SequenceLayerTest): - - @parameterized.parameters( - (((2, 3, 5)),), (((2, 3, 5, 9)),), (((2, 3, 5, 9, 2)),) - ) - def test_flatten(self, shape): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*shape) - l = simple.Flatten.Config(name='flatten').make() - l = self.init_and_bind_layer(key, l, x) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - num_elements = np.prod(shape[2:]) - self.assertEqual(l.get_output_shape_for_sequence(x), (num_elements,)) - self.assertEqual(l.name, 'flatten') - - y = self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - - y_expected = x.apply_values(jnp.reshape, shape[:2] + (num_elements,)) - self.assertSequencesEqual(y, y_expected) +class FlattenTest(test_utils.SequenceLayerTest, spec.FlattenTest): + pass class GlobalReshapeTest(test_utils.SequenceLayerTest): @@ -601,30 +466,7 @@ def test_wrong_shape(self): self.init_and_bind_layer(key, l, x) -class ReshapeTest(test_utils.SequenceLayerTest): - - @parameterized.parameters( - ((2, 3, 5), (1, 5, 1)), - ((2, 3, 5, 9), (3, 3, 5)), - ((2, 3, 1), ()), - ((2, 3), (1,)), - ) - def test_reshape(self, shape, output_shape): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*shape) - l = simple.Reshape.Config(output_shape, name='reshape').make() - l = self.init_and_bind_layer(key, l, x) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), output_shape) - self.assertEqual(l.name, 'reshape') - - y = self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - - y_expected = x.apply_values(jnp.reshape, shape[:2] + output_shape) - self.assertSequencesEqual(y, y_expected) +class ReshapeTest(test_utils.SequenceLayerTest, spec.ReshapeTest): def test_wrong_shape(self): l = simple.Reshape.Config([4], name='reshape').make().bind({}) @@ -637,36 +479,7 @@ def test_wrong_shape(self): l.layer(x, training=False) -class TransposeTest(test_utils.SequenceLayerTest): - - @parameterized.parameters( - ((2, 3, 4, 5), (2, 3), (4, 5)), - ((2, 3, 4, 5, 6), (4, 2, 3), (6, 4, 5)), - ((2, 3, 1, 2, 3), None, (3, 2, 1)), - ((2, 3), tuple(), tuple()), - ((2, 3), None, tuple()), - ) - def test_transpose(self, input_shape, axes, output_shape): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*input_shape) - l = simple.Transpose.Config(axes=axes, name='transpose').make() - l = self.init_and_bind_layer(key, l, x) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), output_shape) - self.assertEqual(l.name, 'transpose') - - y = self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - - if axes is not None: - y_expected = x.apply_values(jnp.transpose, (0, 1) + axes) - else: - axes = (0, 1) + tuple(range(2, x.ndim))[::-1] - y_expected = x.apply_values(jnp.transpose, axes) - - self.assertSequencesEqual(y, y_expected) +class TransposeTest(test_utils.SequenceLayerTest, spec.TransposeTest): @parameterized.parameters( ((2, 3), (2,)), @@ -989,20 +802,20 @@ def layer_vjp_fn( self.assertSequencesEqual(expected_gradients, y_layer_x_grad) -class IdentityTest(test_utils.SequenceLayerTest): +class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): + pass + + +class IdentityTest(test_utils.SequenceLayerTest, spec.IdentityTest): @parameterized.parameters((((2, 3, 5)),), (((2, 3, 5, 9)),)) - def test_identity(self, shape): + def test_jax_specifics(self, shape): key = jax.random.PRNGKey(1234) x = test_utils.random_sequence(*shape) l = simple.Identity(name='identity') l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) - self.assertEqual(l.name, 'identity') - self.verify_contract(l, x, training=False) self.assertEmpty(l.variables) @@ -1047,10 +860,10 @@ def test_emit(self): self.assertSequencesEqual(emits['test_emit'], x) -class OneHotTest(test_utils.SequenceLayerTest): +class OneHotTest(test_utils.SequenceLayerTest, spec.OneHotTest): @parameterized.parameters(((1, 2, 3),), ((2, 3, 5, 9),), ((2, 3, 5, 9, 2),)) - def test_one_hot(self, shape): + def test_variables_empty(self, shape): key = jax.random.PRNGKey(1234) depth = 4 l = simple.OneHot.Config(depth, name='one_hot').make() @@ -1103,11 +916,13 @@ def embedding_layer_from_weights( return layer -class EmbeddingTest(test_utils.SequenceLayerTest): +class EmbeddingTest(test_utils.SequenceLayerTest, spec.EmbeddingTest): - @parameterized.parameters(((1, 2, 3),), ((2, 3, 5, 9),), ((2, 3, 5, 9, 2),)) - def test_embedding(self, shape): - key = jax.random.PRNGKey(1234) + def test_embedding(self): + super().test_embedding() + + # JAX-specific variables check + shape = (2, 3, 5, 9) dimension, num_embeddings = 8, 5 l = simple.Embedding.Config( @@ -1116,22 +931,13 @@ def test_embedding(self, shape): x = test_utils.random_sequence( *shape, dtype=jnp.int32, low=0, high=num_embeddings - 1 ) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual( - l.get_output_shape_for_sequence(x), shape[2:] + (dimension,) - ) - self.assertEqual(l.name, 'embedding') - l = self.init_and_bind_layer(key, l, x) + l = self.init_and_bind_layer(jax.random.PRNGKey(1234), l, x) y = self.verify_contract( l, x, training=False, - # Integer tensors have no gradient to test. test_gradients=False, - # Receptive field test is not supported for integers. test_receptive_field=False, ) @@ -1399,9 +1205,7 @@ class SnakeTest(test_utils.SequenceLayerTest): def test_snake(self, shape, separate_beta: bool): key = jax.random.PRNGKey(1234) x = test_utils.random_sequence(*shape) - l = simple.Snake.Config( - separate_beta=separate_beta, name='snake' - ).make() + l = simple.Snake.Config(separate_beta=separate_beta, name='snake').make() l = self.init_and_bind_layer(key, l, x) self.assertEqual(l.block_size, 1) self.assertEqual(l.output_ratio, 1) @@ -1420,9 +1224,7 @@ def test_snake(self, shape, separate_beta: bool): expected_params['params']['beta_log'] = jnp.zeros( x.channel_shape, dtype=jnp.float32 ) - chex.assert_trees_all_equal_shapes_and_dtypes( - variables, expected_params - ) + chex.assert_trees_all_equal_shapes_and_dtypes(variables, expected_params) class AffineTest(test_utils.SequenceLayerTest): @@ -1605,11 +1407,10 @@ def test_broadcast_failure(self): l.layer(x, training=False) -class PointwiseMathTest(test_utils.SequenceLayerTest): +class PointwiseMathTest(test_utils.SequenceLayerTest, spec.PointwiseMathTest): @parameterized.parameters( (simple.Abs.Config(), jnp.abs, (jnp.float32, jnp.complex64), None), - (simple.Elu.Config(), jax.nn.elu, (jnp.float32,), None), (simple.Exp.Config(), jnp.exp, (jnp.float32,), None), (simple.Gelu.Config(), jax.nn.gelu, (jnp.float32,), None), (simple.LeakyRelu.Config(), jax.nn.leaky_relu, (jnp.float32,), None), @@ -1622,14 +1423,10 @@ class PointwiseMathTest(test_utils.SequenceLayerTest): (simple.Log.Config(), jnp.log, (jnp.float32,), None), (simple.Power.Config(2), jnp.square, (jnp.float32,), None), (simple.Power.Config(0.5), jnp.sqrt, (jnp.float32,), None), - (simple.Relu.Config(), jax.nn.relu, (jnp.float32,), None), - (simple.Sigmoid.Config(), jax.nn.sigmoid, (jnp.float32,), None), - (simple.Softmax.Config(), jax.nn.softmax, (jnp.float32,), None), - (simple.Softplus.Config(), jax.nn.softplus, (jnp.float32,), None), - (simple.Swish.Config(), jax.nn.swish, (jnp.float32,), None), - (simple.Tanh.Config(), jnp.tanh, (jnp.float32,), None), ) - def test_pointwise_math(self, config, op, dtypes, expected_params): + def test_jax_specific_pointwise_math( + self, config, op, dtypes, expected_params + ): key = jax.random.PRNGKey(1234) batch_size, time, channels = 2, 10, 4 for dtype in dtypes: @@ -1664,81 +1461,9 @@ def test_pointwise_math(self, config, op, dtypes, expected_params): y_expected = x.apply_values(op).mask_invalid() self.assertSequencesClose(y, y_expected) - @parameterized.parameters( - (simple.Softmax.Config(), jax.nn.softmax, (jnp.float32,), -1), - (simple.Softmax.Config(), jax.nn.softmax, (jnp.float32,), -2), - (simple.Softmax.Config(), jax.nn.softmax, (jnp.float32,), 2), - (simple.Softmax.Config(), jax.nn.softmax, (jnp.float32,), 3), - ) - def test_pointwise_math_axis(self, config, op, dtypes, axis): - key = jax.random.PRNGKey(1234) - batch_size, time, channels, channels2 = 2, 10, 4, 3 - for dtype in dtypes: - x = test_utils.random_sequence( - batch_size, time, channels, channels2, dtype=dtype - ) - l = dataclasses.replace(config, name='test', axis=axis).make() - l = self.init_and_bind_layer(key, l, x) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual( - l.get_output_shape_for_sequence(x), (channels, channels2) - ) - self.assertEqual(l.name, 'test') - y = self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - - y_expected = x.apply_values( - functools.partial(op, axis=axis) - ).mask_invalid() - self.assertSequencesClose(y, y_expected) - - @parameterized.parameters( - (simple.Softmax.Config(), (2, 10, 4), -2), - (simple.Softmax.Config(), (2, 10, 4), -3), - (simple.Softmax.Config(), (2, 10, 4), 0), - (simple.Softmax.Config(), (2, 10, 4), 1), - (simple.Softmax.Config(), (2, 10), -1), - ) - def test_pointwise_math_axis_invalid(self, config, shape, axis): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*shape) - l = dataclasses.replace(config, name='test', axis=axis).make() - - with self.assertRaises(ValueError): - self.init_and_bind_layer(key, l, x) - -class CastTest(test_utils.SequenceLayerTest): - - @parameterized.parameters( - (((2, 3, 5)), jnp.float16), - (((2, 3, 5, 9)), jnp.int32), - ) - def test_cast(self, shape, target_dtype): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(*shape, dtype=jnp.float32) - l = simple.Cast.Config(target_dtype, name='cast').make() - l = self.init_and_bind_layer(key, l, x) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), shape[2:]) - self.assertEqual(l.name, 'cast') - - test_receptive_field = jnp.issubdtype(target_dtype, jnp.inexact) - y = self.verify_contract( - l, - x, - training=False, - padding_invariance_pad_value=jnp.nan - if target_dtype == jnp.float16 - else 32768, - test_receptive_field=test_receptive_field, - ) - self.assertEmpty(l.variables) - self.assertEqual(y.values.dtype, target_dtype) +class CastTest(test_utils.SequenceLayerTest, spec.CastTest): + pass class ApplyShardingTest(test_utils.SequenceLayerTest): @@ -1781,79 +1506,8 @@ def test_basic(self): # TODO(rryan): Test sharding was applied. -class LambdaTest(test_utils.SequenceLayerTest): - - @parameterized.parameters(True, False) - def test_array_fn(self, mask_required: bool): - def fn(v: jax.Array) -> jax.Array: - if mask_required: - # Change the masked status by adding 1. - v = v + 1.0 - return v.reshape(v.shape + (1,)) > 0.5 - - l = ( - simple.Lambda.Config( - fn, - mask_required=mask_required, - expected_input_spec=types.ShapeDType((5,), jnp.float32), - name='lambda', - ) - .make() - .bind({}) - ) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - # Output spec reflects the changed shape and dtype. - x = test_utils.random_sequence(2, 3, 5) - self.assertEqual(l.get_output_shape_for_sequence(x), (5, 1)) - self.assertEqual(l.get_output_dtype(x.dtype), jnp.bool_) - self.assertEqual(l.name, 'lambda') - y = self.verify_contract( - l, - x, - training=False, - # Receptive field test is not supported for bools. - test_receptive_field=False, - ) - self.assertEmpty(l.variables) - self.assertSequencesClose(y, x.apply_values(fn).mask_invalid()) - - @parameterized.parameters(True, False) - def test_sequence_fn(self, mask_required: bool): - def fn(x: types.Sequence) -> types.Sequence: - if mask_required: - # Change the masked status by adding 1. - x = x.apply_values(lambda v: v + 1.0) - return x.apply_values_masked(lambda v: v.reshape(v.shape + (1,)) > 0.5) - - l = ( - simple.Lambda.Config( - fn, - sequence_input=True, - expected_input_spec=types.ShapeDType((5,), jnp.float32), - name='lambda', - ) - .make() - .bind({}) - ) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - # Output spec reflects the changed shape and dtype. - x = test_utils.random_sequence(2, 3, 5) - self.assertEqual(l.get_output_shape_for_sequence(x), (5, 1)) - self.assertEqual(l.get_output_dtype(x.dtype), jnp.bool_) - self.assertEqual(l.name, 'lambda') - y = self.verify_contract( - l, - x, - training=False, - # Receptive field test is not supported for bools. - test_receptive_field=False, - ) - self.assertEmpty(l.variables) - self.assertSequencesClose(y, fn(x).mask_invalid()) +class LambdaTest(test_utils.SequenceLayerTest, spec.LambdaTest): + """Test behavior of Lambda layer.""" def test_invalid_input(self): """Input that does not match expected_input_spec raises ValueError.""" @@ -1895,23 +1549,19 @@ def test_invalid_fn(self): l.layer(x, training=False) -class CheckpointNameTest(test_utils.SequenceLayerTest): +class CheckpointNameTest(test_utils.SequenceLayerTest, spec.CheckpointNameTest): + """Test behavior of CheckpointName layer.""" def test_basic(self): - key = jax.random.PRNGKey(1234) + super().test_basic() + x = test_utils.random_sequence(2, 3, 5) + key = jax.random.PRNGKey(1234) l = simple.CheckpointName.Config( checkpoint_name='test', name='checkpoint_name' ).make() l = self.init_and_bind_layer(key, l, x) - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), (5,)) - self.assertEqual(l.name, 'checkpoint_name') - self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - policy = jax.checkpoint_policies.save_only_these_names('test') @functools.partial(jax.checkpoint, policy=policy) @@ -1928,42 +1578,12 @@ def f(x: types.Sequence) -> types.Sequence: ) -class Downsample1DTest(test_utils.SequenceLayerTest): - - @parameterized.parameters(((2, 3, 5), 2), ((2, 3, 5, 9), 3)) - def test_downsample1d(self, shape, rate): - l = simple.Downsample1D.Config(rate, name='downsample_1d').make().bind({}) +class Downsample1DTest(test_utils.SequenceLayerTest, spec.Downsample1DTest): + pass - self.assertEqual(l.block_size, rate) - self.assertEqual(1 / l.output_ratio, rate) - self.assertTrue(l.supports_step) - self.assertEqual(l.name, 'downsample_1d') - self.assertEmpty(l.variables) - - x = test_utils.random_sequence(*shape) - self.assertEqual(l.get_output_shape_for_sequence(x), x.channel_shape) - y = self.verify_contract(l, x, training=False) - self.assertAllEqual(x.values[:, ::rate], y.values) - self.assertAllEqual(x.mask[:, ::rate], y.mask) - - -class Upsample1DTest(test_utils.SequenceLayerTest): - - @parameterized.parameters(((2, 3, 5), 2), ((2, 3, 5, 9), 3)) - def test_upsample1d(self, shape, rate): - l = simple.Upsample1D.Config(rate, name='upsample_1d').make().bind({}) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, rate) - self.assertTrue(l.supports_step) - self.assertEqual(l.name, 'upsample_1d') - self.assertEmpty(l.variables) - x = test_utils.random_sequence(*shape) - self.assertEqual(l.get_output_shape_for_sequence(x), x.channel_shape) - y = self.verify_contract(l, x, training=False) - for i in range(rate): - self.assertAllEqual(x.values, y.values[:, i::rate]) +class Upsample1DTest(test_utils.SequenceLayerTest, spec.Upsample1DTest): + pass class Upsample2DTest(test_utils.SequenceLayerTest): @@ -1989,24 +1609,8 @@ def test_upsample2d(self, shape, rate): self.assertAllEqual(x.values, y.values[:, i :: rate[0], j :: rate[1], :]) -class MaskInvalidTest(test_utils.SequenceLayerTest): - - def test_basic(self): - x = test_utils.random_sequence(2, 15, 5) - l = simple.MaskInvalid.Config(name='mask_invalid').make().bind({}) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual(l.get_output_shape_for_sequence(x), (5,)) - self.assertEqual(l.name, 'mask_invalid') - self.verify_contract(l, x, training=False) - self.assertEmpty(l.variables) - - x = x.mask_invalid(np.nan) - self.assertIsInstance(x, types.Sequence) - y = l.layer(x, training=False) - self.assertIsInstance(y, types.MaskedSequence) - self.assertSequencesEqual(x.mask_invalid(), y) +class MaskInvalidTest(test_utils.SequenceLayerTest, spec.MaskInvalidTest): + pass class ReduceTest(test_utils.SequenceLayerTest): @@ -2089,106 +1693,8 @@ def test_reduce_invalid_axis(self, layer_config, axis): self.init_and_bind_layer(key, l, x) -class Has: - """A simple `HAS(v)` matcher that tests whether something has `v` in it.""" - - def __init__(self, value): - self._v = value - - def __eq__(self, o): - return self._v in o - - def __ne__(self, o): - return not self == o - - def __repr__(self): - return '' % self._v - - -class Not: - """Negates a matcher.""" - - def __init__(self, matcher): - self._matcher = matcher - - def __eq__(self, o): - return self._matcher != o - - def __ne__(self, o): - return not self == o - - def __repr__(self): - return '' % self._matcher - - -class LoggingTest(test_utils.SequenceLayerTest): - - @mock.patch.object(logging, 'info', wraps=logging.info) - def test_logs_tensors(self, mock_logger): - x = types.Sequence.from_values(jnp.asarray([[1.414, 2, 3, 4]])) - state = types.Sequence.from_values(jnp.asarray([[1, 2.718, 3, 4]])) - training = False - constants = { - 'foo': jnp.asarray([[1, 2, 3.14, 4]]), - 'bar': np.asarray([[1, 2, 3, 4.2]]), - } - - with self.subTest('prefix'): - l = simple.Logging.Config(prefix='test string').make().bind({}) - l.layer(x, training=training, constants=constants) - mock_logger.assert_called_with(Has('test string')) - - with self.subTest('specs_only'): - l = simple.Logging.Config(dump_tensors=False).make().bind({}) - with self.subTest('layer'): - l.layer(x, training=training, constants=constants) - mock_logger.assert_called_with(Not(Has('1.414'))) - mock_logger.assert_called_with(Not(Has('3.14'))) - mock_logger.assert_called_with(Not(Has('4.2'))) - mock_logger.assert_called_with(Has('(1, 4)')) - mock_logger.assert_called_with(Has('float32')) - with self.subTest('get_initial_state'): - l.get_initial_state( - batch_size=x.shape[0], - input_spec=x.channel_spec, - training=training, - constants=constants, - ) - mock_logger.assert_called_with(Not(Has('3.14'))) - mock_logger.assert_called_with(Not(Has('4.2'))) - mock_logger.assert_called_with(Has('(1, 4)')) - mock_logger.assert_called_with(Has('float32')) - with self.subTest('step'): - l.step(x, state, training=training, constants=constants) - mock_logger.assert_called_with(Not(Has('1.414'))) - mock_logger.assert_called_with(Not(Has('2.718'))) - mock_logger.assert_called_with(Not(Has('3.14'))) - mock_logger.assert_called_with(Not(Has('4.2'))) - mock_logger.assert_called_with(Has('(1, 4)')) - mock_logger.assert_called_with(Has('float32')) - - with self.subTest('dumps_tensors'): - l = simple.Logging.Config(dump_tensors=True).make().bind({}) - with self.subTest('layer'): - l.layer(x, training=training, constants=constants) - mock_logger.assert_called_with(Has('1.414')) - mock_logger.assert_called_with(Has('3.14')) - mock_logger.assert_called_with(Has('4.2')) - with self.subTest('get_initial_state'): - l.get_initial_state( - batch_size=x.shape[0], - input_spec=x.channel_spec, - training=training, - constants=constants, - ) - mock_logger.assert_called_with(Has('3.14')) - mock_logger.assert_called_with(Has('4.2')) - with self.subTest('step'): - l.step(x, state, training=training, constants=constants) - mock_logger.assert_called_with(Has('1.414')) - mock_logger.assert_called_with(Has('2.718')) - mock_logger.assert_called_with(Has('3.14')) - mock_logger.assert_called_with(Has('4.2')) +class LoggingTest(test_utils.SequenceLayerTest, spec.LoggingTest): + """Test behavior of Logging layer.""" class ArgmaxTest(test_utils.SequenceLayerTest): @@ -2222,63 +1728,7 @@ def test_argmax(self, input_array: jnp.ndarray): self.assertAllEqual(y.values, jnp.array([[2], [0]])) -class SqueezeTest(test_utils.SequenceLayerTest): - - @parameterized.named_parameters( - dict( - testcase_name='float_input', - input_array=np.array( - [[[3]]], - dtype=np.float32, - ), - expected_output=np.array([[3]]), - ), - dict( - testcase_name='int_input', - input_array=np.array( - [[[3]]], - dtype=np.int32, - ), - expected_output=np.array([[3]], dtype=np.int32), - ), - dict( - testcase_name='no_op_input', - input_array=np.array( - [[3]], - dtype=np.float32, - ), - expected_output=np.array([[3]]), - ), - dict( - testcase_name='input_with_extra_dims', - input_array=np.array( - [[[[[3], [4]]]]], - dtype=np.float32, - ), - expected_output=np.array([[[3, 4]]]), - ), - ) - def test_squeeze( - self, input_array: jnp.ndarray, expected_output: jnp.ndarray - ): - key = jax.random.PRNGKey(1234) - x = types.Sequence.from_values(input_array) - l = simple.Squeeze.Config(name='squeeze').make() - l = self.init_and_bind_layer(key, l, x) - - _ = l.layer(x, training=False) - - self.assertEqual(l.block_size, 1) - self.assertEqual(l.output_ratio, 1) - self.assertEqual( - l.get_output_shape_for_sequence(x), expected_output.shape[2:] - ) - self.assertEqual(l.name, 'squeeze') - test_receptive_field = jnp.issubdtype(input_array.dtype, jnp.inexact) - self.verify_contract( - l, x, training=False, test_receptive_field=test_receptive_field - ) - self.assertEmpty(l.variables) +class SqueezeTest(test_utils.SequenceLayerTest, spec.SqueezeTest): @parameterized.parameters( ((2, 3, 1, 1, 1), 2, (1, 1)), diff --git a/sequence_layers/jax/test_utils.py b/sequence_layers/jax/test_utils.py index b336f15..c81aff3 100644 --- a/sequence_layers/jax/test_utils.py +++ b/sequence_layers/jax/test_utils.py @@ -716,6 +716,10 @@ def setUp(self): random.seed(123456789) np.random.seed(123456789) + @override + def get_variables(self, layer): + return layer.variables + def init_and_bind_layer( self, key: jax.Array, @@ -765,10 +769,12 @@ def randomize_weights_fn(variables): return layer.bind(variables) - def init_layer(self, layer, x, **kwargs): + def init_layer(self, layer, x, bind_only=False): """Initialize and bind variables for JAX.""" + if bind_only: + return layer.bind({}) key = jax.random.PRNGKey(1234) - return self.init_and_bind_layer(key, layer, x, **kwargs) + return self.init_and_bind_layer(key, layer, x) def verify_masked(self, x: types.Sequence): """Asserts all invalid timesteps in x have values masked to zero.""" diff --git a/sequence_layers/jax/types.py b/sequence_layers/jax/types.py index 1c82b6c..edb714d 100644 --- a/sequence_layers/jax/types.py +++ b/sequence_layers/jax/types.py @@ -20,19 +20,8 @@ import functools import math import typing -from typing import ( - Any, - Callable, - Concatenate, - Generic, - Iterable, - Literal, - MutableMapping, - override, - ParamSpec, - Protocol, - Self, -) +from typing import (Any, Callable, Concatenate, Generic, Iterable, Literal, + MutableMapping, override, ParamSpec, Protocol, Self) from typing import Sequence as TypingSequence from typing import TypeVar @@ -49,6 +38,31 @@ from sequence_layers.jax import typing as jt from sequence_layers.specs import types as spec + +def _to_tuple(x: complex | list[Any]) -> complex | tuple[Any, ...]: + """Replaces lists in a pytree of complex with tuples.""" + if isinstance(x, list): + return tuple(_to_tuple(i) for i in x) + else: + return x + + +@dataclasses.dataclass(frozen=True) +class HashableArray: + """Hashable multidimensional array of tuples.""" + + data: complex | tuple[Any, ...] + dtype: np.dtype + + @classmethod + def from_array(cls, x: np.ndarray) -> 'HashableArray': + x = np.asarray(x) + return HashableArray(_to_tuple(x.tolist()), x.dtype) + + def to_array(self) -> np.ndarray: + return np.asarray(self.data, dtype=self.dtype) + + __all__ = ( # go/keep-sorted start 'ArrayLike', @@ -59,6 +73,7 @@ 'Emits', 'Emitting', 'ExpandedMaskT', + 'HashableArray', 'MASK_DTYPE', 'MaskT', 'MaskedSequence', @@ -107,6 +122,7 @@ # False indicates it is invalid. MaskT = TypeVar('MaskT', bound=jt.Bool[jt.ArrayT, 'B T']) + # An integer batched lengths array. LengthsT = TypeVar('LengthsT', bound=jt.Int[jt.ArrayT, 'B']) @@ -1541,7 +1557,7 @@ def fn(self, values: ValuesT, mask: MaskT) -> tuple[ValuesT, MaskT]: @property @override - def mask_required(self): + def mask_required(self) -> bool: """Returns true if fn can change the sequence's masked state. If fn(0) -> 0, then mask_required() is False. diff --git a/sequence_layers/jax/types_test.py b/sequence_layers/jax/types_test.py index 89df70e..2efd604 100644 --- a/sequence_layers/jax/types_test.py +++ b/sequence_layers/jax/types_test.py @@ -195,7 +195,9 @@ def fn(x: types.Sequence) -> types.Sequence: self.assertSequencesEqual(y, x) -class SequenceLayerConfigTest(test_utils.SequenceLayerTest, spec.SequenceLayerConfigTest): +class SequenceLayerConfigTest( + test_utils.SequenceLayerTest, spec.SequenceLayerConfigTest +): def test_copy_raises_on_mutable_attribute(self): @@ -256,11 +258,19 @@ class EmittingTest(test_utils.SequenceLayerTest, spec.EmittingTest): pass -class StatelessEmittingTest(test_utils.SequenceLayerTest, spec.StatelessEmittingTest): +class StatelessEmittingTest( + test_utils.SequenceLayerTest, spec.StatelessEmittingTest +): pass -class StatelessPointwiseFunctorTest(test_utils.SequenceLayerTest, spec.StatelessPointwiseFunctorTest): +class StatelessPointwiseFunctorTest( + test_utils.SequenceLayerTest, spec.StatelessPointwiseFunctorTest +): + pass + + +class HashableArrayTest(test_utils.SequenceLayerTest, spec.HashableArrayTest): pass diff --git a/sequence_layers/jax/typing.py b/sequence_layers/jax/typing.py index a83a2ba..8bd68d4 100644 --- a/sequence_layers/jax/typing.py +++ b/sequence_layers/jax/typing.py @@ -30,7 +30,7 @@ import typeguard if TYPE_CHECKING: - ArrayT = jax.Array | np.ndarray + ArrayT = jax.Array | np.ndarray | jax.ShapeDtypeStruct else: class _MetaArrayT(type): From 412cb84bd7695a21fddfbf75872b6496da8fc089 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 17 Apr 2026 00:39:50 -0700 Subject: [PATCH 4/5] refactor(mlx): Standardize simple layers with spec behaviors --- sequence_layers/mlx/__init__.py | 98 +- sequence_layers/mlx/backend.py | 52 + sequence_layers/mlx/backend_test.py | 4 + sequence_layers/mlx/simple.py | 1422 +++++++++++++++++++-------- sequence_layers/mlx/simple_test.py | 578 +++-------- sequence_layers/mlx/test_utils.py | 62 +- sequence_layers/mlx/types.py | 69 +- sequence_layers/mlx/types_test.py | 4 + 8 files changed, 1431 insertions(+), 858 deletions(-) diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 4c861f5..6a17923 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -13,10 +13,96 @@ # limitations under the License. """Sequence layers in MLX.""" -# (re-export the names for typechecking) -from . import backend as backend -from . import types as types -from . import test_utils as test_utils -from .test_utils import SequenceLayerTest +from . import backend +from . import simple +from . import types +# CRITICAL: Do NOT use wildcard imports (e.g., `from .simple import *`) here. +# Pyrefly (our static analysis tool) has a known limitation with cross-module +# resolution of diamond inheritance chains. When wildcard imports are used to +# re-export classes from `simple.py` (which combine `types` and `spec` bases), +# Pyrefly fails to resolve the concrete method implementations in `mlx/types.py` +# and flags all instances as abstract (`bad-instantiation` false positives). +# +# Explicit imports (e.g., `from .simple import Relu`) DO NOT trigger this issue. +# If you need to expose specific layers at the package level, import them +# explicitly instead of using a star import. +from .simple import Abs +from .simple import Add +from .simple import Cast +from .simple import CheckpointName +from .simple import Downsample1D +from .simple import Dropout +from .simple import Elu +from .simple import Embedding +from .simple import Exp +from .simple import ExpandDims +from .simple import Flatten +from .simple import GatedLinearUnit +from .simple import GatedTanhUnit +from .simple import GatedUnit +from .simple import Gelu +from .simple import Identity +from .simple import Lambda +from .simple import LeakyRelu +from .simple import Log +from .simple import Logging +from .simple import MaskInvalid +from .simple import OneHot +from .simple import Relu +from .simple import Reshape +from .simple import Scale +from .simple import Sigmoid +from .simple import Softmax +from .simple import Softplus +from .simple import Squeeze +from .simple import Swish +from .simple import Tanh +from .simple import Transpose +from .simple import Upsample1D +from .types import MaskedSequence +from .types import Sequence +from .types import SequenceLayer +from .types import SequenceLayerConfig -from sequence_layers.mlx.types import * +__all__ = [ + 'backend', + 'types', + 'simple', + 'Sequence', + 'MaskedSequence', + 'SequenceLayer', + 'SequenceLayerConfig', + 'Identity', + 'Relu', + 'Gelu', + 'Abs', + 'Exp', + 'Log', + 'Swish', + 'Tanh', + 'Sigmoid', + 'LeakyRelu', + 'Elu', + 'Softmax', + 'Softplus', + 'Cast', + 'Scale', + 'Add', + 'MaskInvalid', + 'GatedUnit', + 'GatedLinearUnit', + 'GatedTanhUnit', + 'Flatten', + 'Reshape', + 'ExpandDims', + 'Squeeze', + 'Transpose', + 'OneHot', + 'Embedding', + 'Dropout', + 'Downsample1D', + 'Upsample1D', + 'CheckpointName', + 'Lambda', + 'Logging', +] diff --git a/sequence_layers/mlx/backend.py b/sequence_layers/mlx/backend.py index 4dd649c..d631734 100644 --- a/sequence_layers/mlx/backend.py +++ b/sequence_layers/mlx/backend.py @@ -3,6 +3,7 @@ from typing import override import mlx.core as mx +import mlx.nn as nn_mlx from sequence_layers.specs import backend as spec from sequence_layers.specs import types as types_spec @@ -27,5 +28,56 @@ def zeros(self, shape, dtype=None) -> types_spec.Array: def concatenate(self, arrays, axis=0) -> types_spec.Array: return mx.concatenate(arrays, axis=axis) + @override + def abs(self, x) -> types_spec.Array: + return mx.abs(x) + + @override + def exp(self, x) -> types_spec.Array: + return mx.exp(x) + + @override + def log(self, x) -> types_spec.Array: + return mx.log(x) + xp: spec.xp = BackendWrapper() + + +class NNWrapper(spec.nn): + """Wrapper around MLX activations to match backend protocol.""" + + @override + def relu(self, x: types_spec.Array) -> types_spec.Array: + return nn_mlx.relu(x) + + @override + def sigmoid(self, x: types_spec.Array) -> types_spec.Array: + return mx.sigmoid(x) + + @override + def tanh(self, x: types_spec.Array) -> types_spec.Array: + return mx.tanh(x) + + @override + def swish(self, x: types_spec.Array) -> types_spec.Array: + return nn_mlx.silu(x) + + @override + def gelu(self, x: types_spec.Array) -> types_spec.Array: + return nn_mlx.gelu(x) + + @override + def elu(self, x: types_spec.Array) -> types_spec.Array: + return nn_mlx.elu(x) + + @override + def softplus(self, x: types_spec.Array) -> types_spec.Array: + return nn_mlx.softplus(x) + + @override + def softmax(self, x: types_spec.Array, axis: int = -1) -> types_spec.Array: + return mx.softmax(x, axis=axis) + + +nn: spec.nn = NNWrapper() diff --git a/sequence_layers/mlx/backend_test.py b/sequence_layers/mlx/backend_test.py index 4c8ab5f..c5b8138 100644 --- a/sequence_layers/mlx/backend_test.py +++ b/sequence_layers/mlx/backend_test.py @@ -10,5 +10,9 @@ class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): pass +class BackendNNTest(test_utils.SequenceLayerTest, spec.BackendNNTest): + """Tests for MLX backend.nn operations.""" + + if __name__ == '__main__': absltest.main() diff --git a/sequence_layers/mlx/simple.py b/sequence_layers/mlx/simple.py index 830df0b..140cc34 100644 --- a/sequence_layers/mlx/simple.py +++ b/sequence_layers/mlx/simple.py @@ -1,21 +1,70 @@ """Simple sequence layers for MLX.""" import dataclasses +from fractions import Fraction import math +from typing import Any, Callable, override -from typing import Callable - +from absl import logging +from mlx import nn import mlx.core as mx -import mlx.nn as nn import numpy as np -from sequence_layers.mlx import basic_types as bt -from sequence_layers.mlx import init_mapping from sequence_layers.mlx import types -from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig +from sequence_layers.specs import simple as spec + +Sequence = types.Sequence +MaskedSequence = types.MaskedSequence +ShapeDType = types.ShapeDType + -Sequence = bt.Sequence -MaskedSequence = bt.MaskedSequence +def _to_tuple(x: complex | list[Any]) -> complex | tuple[Any, ...]: + """Converts a nested list to a nested tuple.""" + if isinstance(x, list): + return tuple(_to_tuple(item) for item in x) + return x + + +@dataclasses.dataclass(frozen=True) +class HashableArray(spec.HashableArray): + """Hashable multidimensional array of tuples.""" + + data: complex | tuple[Any, ...] + dtype: np.dtype + + @classmethod + def from_array(cls, x: np.ndarray) -> 'HashableArray': + """Creates a HashableArray from a numpy array.""" + x = np.asarray(x) + return HashableArray(_to_tuple(x.tolist()), x.dtype) + + @override + def to_array(self) -> np.ndarray: + return np.asarray(self.data, dtype=self.dtype) + + +def _to_mx_dtype(dtype: Any) -> mx.Dtype | None: + """Converts various dtype representations to MLX DType.""" + if dtype is None: + return None + if isinstance(dtype, str): + if dtype == 'float32': + return mx.float32 + if dtype == 'float16': + return mx.float16 + if dtype == 'int32': + return mx.int32 + if dtype == 'bool': + return mx.bool_ + if dtype == np.float32: + return mx.float32 + if dtype == np.float16: + return mx.float16 + if dtype == np.int32: + return mx.int32 + if dtype in (np.bool_, bool): + return mx.bool_ + return dtype # --------------------------------------------------------------------------- @@ -23,164 +72,400 @@ # --------------------------------------------------------------------------- -class Identity(types.PreservesType, types.StatelessPointwise): +class Identity( + types.PreservesType, + types.StatelessPointwise, + spec.Identity[types.Sequence, types.ShapeDType], +): """Identity pass-through of the input.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.Identity.Config): + """Configuration for Identity layer.""" + name: str | None = None + @override def make(self) -> 'Identity': - return Identity.from_config(self) + """Creates the Identity layer.""" + return Identity(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + @override @types.check_layer - def layer(self, x, *, constants=None): + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Returns the input sequence unchanged.""" return x - @classmethod - def from_config(cls, config): - return cls() - # --------------------------------------------------------------------------- # Activation layers # --------------------------------------------------------------------------- -class Relu(types.PreservesType, types.StatelessPointwiseFunctor): +class Relu( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Relu[types.Sequence, types.ShapeDType], +): """A Relu layer.""" + @dataclasses.dataclass(frozen=True) + class Config(spec.Relu.Config): + """Configuration for Relu layer.""" + + name: str | None = None + + @override + def make(self) -> 'Relu': + """Creates the Relu layer.""" + return Relu(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + @property + @override def mask_required(self): return False - def fn(self, values, mask): + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using ReLU.""" return nn.relu(values), mask - @classmethod - def from_config(cls, config): - return cls() - -class Gelu(types.PreservesType, types.StatelessPointwiseFunctor): +class Gelu( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Gelu[types.Sequence, types.ShapeDType], +): """A Gelu layer.""" + @dataclasses.dataclass(frozen=True) + class Config(spec.Gelu.Config): + """Configuration for Gelu layer.""" + + name: str | None = None + + @override + def make(self) -> 'Gelu': + """Creates the Gelu layer.""" + return Gelu(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + @property + @override def mask_required(self): return False - def fn(self, values, mask): + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using GELU.""" return nn.gelu(values), mask - @classmethod - def from_config(cls, config): - return cls() +class Abs( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Abs[types.Sequence, types.ShapeDType], +): + """Absolute value layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(spec.Abs.Config): + """Configuration for Abs layer.""" + + name: str | None = None + + @override + def make(self) -> 'Abs': + """Creates the Abs layer.""" + return Abs(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + + @property + @override + def mask_required(self): + return False + + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using absolute value.""" + return mx.abs(values), mask + + +class Exp( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Exp[types.Sequence, types.ShapeDType], +): + """Exponential layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(spec.Exp.Config): + """Configuration for Exp layer.""" + + name: str | None = None + + @override + def make(self) -> 'Exp': + """Creates the Exp layer.""" + return Exp(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config -class Swish(types.PreservesType, types.StatelessPointwiseFunctor): + @property + @override + def mask_required(self): + return False + + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using exponential.""" + return mx.exp(values), mask + + +class Log( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Log[types.Sequence, types.ShapeDType], +): + """Logarithm layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(spec.Log.Config): + """Configuration for Log layer.""" + + name: str | None = None + + @override + def make(self) -> 'Log': + """Creates the Log layer.""" + return Log(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + + @property + @override + def mask_required(self): + return False + + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using natural logarithm.""" + return mx.log(values), mask + + +class Swish( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Swish[types.Sequence, types.ShapeDType], +): """A Swish/SiLU layer.""" + @dataclasses.dataclass(frozen=True) + class Config(spec.Swish.Config): + """Configuration for Swish layer.""" + + name: str | None = None + + @override + def make(self) -> 'Swish': + """Creates the Swish layer.""" + return Swish(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + @property + @override def mask_required(self): return False - def fn(self, values, mask): + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using Swish (SiLU).""" return nn.silu(values), mask - @classmethod - def from_config(cls, config): - return cls() - -class Tanh(types.PreservesType, types.StatelessPointwiseFunctor): +class Tanh( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Tanh[types.Sequence, types.ShapeDType], +): """A tanh layer.""" + @dataclasses.dataclass(frozen=True) + class Config(spec.Tanh.Config): + """Configuration for Tanh layer.""" + + name: str | None = None + + @override + def make(self) -> 'Tanh': + """Creates the Tanh layer.""" + return Tanh(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + @property + @override def mask_required(self): return False - def fn(self, values, mask): + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using hyperbolic tangent.""" return mx.tanh(values), mask - @classmethod - def from_config(cls, config): - return cls() - -class Sigmoid(types.PreservesType, types.StatelessPointwiseFunctor): +class Sigmoid( + types.PreservesType, types.StatelessPointwiseFunctor, spec.Sigmoid +): """A sigmoid layer.""" + @dataclasses.dataclass(frozen=True) + class Config(spec.Sigmoid.Config): + """Configuration for Sigmoid layer.""" + + name: str | None = None + + @override + def make(self) -> 'Sigmoid': + """Creates the Sigmoid layer.""" + return Sigmoid(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + @property + @override def mask_required(self): return False - def fn(self, values, mask): + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using Sigmoid.""" return mx.sigmoid(values), mask - @classmethod - def from_config(cls, config): - return cls() - -class LeakyRelu(types.PreservesType, types.StatelessPointwiseFunctor): +class LeakyRelu( + types.PreservesType, types.StatelessPointwiseFunctor, spec.LeakyRelu +): """A Leaky Relu layer.""" - def __init__(self, negative_slope=0.01): + @dataclasses.dataclass(frozen=True) + class Config(spec.LeakyRelu.Config): + """Configuration for LeakyRelu layer.""" + + negative_slope: float = 0.01 + name: str | None = None + + @override + def make(self) -> 'LeakyRelu': + """Creates the LeakyRelu layer.""" + return LeakyRelu(self) + + def __init__(self, config: Config): super().__init__() - self._negative_slope = negative_slope + self.config = config @property + @override def mask_required(self): return False - def fn(self, values, mask): - return nn.leaky_relu(values, self._negative_slope), mask - - @classmethod - def from_config(cls, config): - return cls(negative_slope=config.negative_slope) + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using Leaky ReLU.""" + return nn.leaky_relu(values, self.config.negative_slope), mask -class Elu(types.PreservesType, types.StatelessPointwiseFunctor): +class Elu( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.Elu[types.Sequence, types.ShapeDType], +): """An ELU activation layer.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.Elu.Config): + """Configuration for Elu layer.""" + alpha: complex = 1.0 name: str | None = None + @override def make(self) -> 'Elu': - return Elu.from_config(self) + """Creates the Elu layer.""" + return Elu(self) - def __init__(self, alpha=1.0): + def __init__(self, config: Config): super().__init__() - self._alpha = alpha + self.config = config @property + @override def mask_required(self): return False - def fn(self, values, mask): - return nn.elu(values, self._alpha), mask - - @classmethod - def from_config(cls, config): - return cls(alpha=config.alpha) + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using ELU.""" + return nn.elu(values, self.config.alpha), mask -class Softmax(types.PreservesType, types.StatelessPointwiseFunctor): +class Softmax( + types.PreservesType, types.StatelessPointwiseFunctor, spec.Softmax +): """A softmax layer.""" - def __init__(self, axis=-1): + @dataclasses.dataclass(frozen=True) + class Config(spec.Softmax.Config): + """Configuration for Softmax layer.""" + + axis: int = -1 + name: str | None = None + + @override + def make(self) -> 'Softmax': + """Creates the Softmax layer.""" + return Softmax(self) + + def __init__(self, config: Config): super().__init__() - self._axis = axis + self.config = config @property + @override def mask_required(self): return False - def fn(self, values, mask): - axis = self._axis + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using Softmax.""" + axis = self.config.axis if (axis if axis >= 0 else values.ndim + axis) < 2: raise ValueError( 'The softmax cannot be applied on the batch or time' @@ -188,123 +473,206 @@ def fn(self, values, mask): ) return mx.softmax(values, axis=axis), mask - @classmethod - def from_config(cls, config): - return cls(axis=config.axis) - -class Softplus(types.PreservesType, types.StatelessPointwiseFunctor): +class Softplus( + types.PreservesType, types.StatelessPointwiseFunctor, spec.Softplus +): """A softplus layer.""" + @dataclasses.dataclass(frozen=True) + class Config(spec.Softplus.Config): + """Configuration for Softplus layer.""" + + name: str | None = None + + @override + def make(self) -> 'Softplus': + """Creates the Softplus layer.""" + return Softplus(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + @property + @override def mask_required(self): return False - def fn(self, values, mask): + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Transforms each scalar in values independently using Softplus.""" return nn.softplus(values), mask - @classmethod - def from_config(cls, config): - return cls() - # --------------------------------------------------------------------------- # Value manipulation # --------------------------------------------------------------------------- -class Cast(types.StatelessPointwiseFunctor): +class Cast( + types.StatelessPointwiseFunctor, spec.Cast[types.Sequence, types.ShapeDType] +): """Cast input values to the specified type.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(types.SequenceLayerConfig, spec.Cast.Config): + """Configuration for Cast layer.""" + dtype: object = mx.float32 name: str | None = None + @override def make(self) -> 'Cast': - return Cast.from_config(self) + return Cast(self) - def __init__(self, dtype): + def __init__(self, config: Config): super().__init__() - self._dtype = dtype + self.config = config + self._dtype = _to_mx_dtype(config.dtype) @property + @override def mask_required(self): return False - def fn(self, values, mask): - return values.astype(self._dtype), mask + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Casts input values to the specified type.""" + return values.astype(self._dtype), mask # type: ignore - def get_output_dtype(self, input_dtype, *, constants=None): + @override + def get_output_dtype(self, input_dtype, *, constants=None) -> mx.Dtype: + assert self._dtype is not None return self._dtype - @classmethod - def from_config(cls, config): - from sequence_layers.mlx.init_mapping import _to_mx_dtype - - return cls(dtype=_to_mx_dtype(config.dtype)) - -class Scale(types.PreservesType, types.StatelessPointwise): +class Scale( + types.PreservesType, + types.StatelessPointwise, + spec.Scale[types.Sequence, types.ShapeDType], +): """Scales the input by a provided constant or array.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): - scale: object = 1.0 + class Config(spec.Scale.Config): + """Configuration for Scale layer.""" + + scale: complex | np.ndarray | types.HashableArray = 1.0 name: str | None = None + def __post_init__(self): + object.__setattr__( + self, 'scale', types.HashableArray.from_array(self.scale) + ) + + @override def make(self) -> 'Scale': - return Scale.from_config(self) + """Creates the Scale layer.""" + return Scale(self) - def __init__(self, scale): + def __init__(self, config: Config): super().__init__() - if isinstance(scale, (int, float, complex)): - self._scale = scale - else: - self._scale = mx.array(np.asarray(scale)) + self.config = config + assert isinstance(config.scale, types.HashableArray) + self._scale = config.scale.to_array() - @types.check_layer - def layer(self, x, *, constants=None): - s = self._scale - if isinstance(s, mx.array): - s = s.astype(x.dtype) - return x.apply_values_masked(lambda v: v * s) + @override + def get_output_shape( + self, + input_shape: types.ShapeLike, + *, + constants: types.Constants | None = None, + ) -> types.Shape: + del constants + s_shape = ( + () + if isinstance(self._scale, (int, float, complex)) + else self._scale.shape + ) + if len(s_shape) > len(input_shape): + raise ValueError( + f'Scale parameter has too many dimensions ({len(s_shape)}) to' + f' broadcast with input shape ({len(input_shape)}).' + ) + try: + return np.broadcast_shapes(tuple(input_shape), s_shape) + except ValueError as e: + raise ValueError( + f'Cannot broadcast shape {input_shape} with scale shape {s_shape}' + ) from e - @classmethod - def from_config(cls, config): - scale = config.scale - if hasattr(scale, 'data') and hasattr(scale, 'dtype'): - scale = np.array(scale.data, dtype=scale.dtype) - elif hasattr(scale, 'array'): - scale = np.asarray(scale.array) - return cls(scale=scale) + @types.check_layer + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Scales the input sequence by a learned or fixed scale.""" + return x.apply_values_masked(lambda v: v * self._scale) -class Add(types.PreservesType, types.StatelessPointwise): +class Add( + types.PreservesType, + types.StatelessPointwise, + spec.Add[types.Sequence, types.ShapeDType], +): """Adds a provided constant or array to the input.""" - def __init__(self, shift): - super().__init__() - if isinstance(shift, (int, float, complex)): - self._shift = shift - else: - self._shift = mx.array(np.asarray(shift)) + @dataclasses.dataclass(frozen=True) + class Config(spec.Add.Config): + """Configuration for Add layer.""" - @types.check_layer - def layer(self, x, *, constants=None): - s = self._shift - if isinstance(s, mx.array): - s = s.astype(x.dtype) - return x.apply_values(lambda v: v + s) + shift: Any + name: str | None = None - @classmethod - def from_config(cls, config): + @override + def make(self) -> 'Add': + """Creates the Add layer.""" + return Add(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config shift = config.shift if hasattr(shift, 'data') and hasattr(shift, 'dtype'): - shift = np.array(shift.data, dtype=shift.dtype) + self._shift = np.array(shift.data, dtype=shift.dtype) elif hasattr(shift, 'array'): - shift = np.asarray(shift.array) - return cls(shift=shift) + self._shift = np.asarray(shift.array) + else: + self._shift = shift + + @override + def get_output_shape( + self, + input_shape: types.ShapeLike, + *, + constants: types.Constants | None = None, + ) -> types.Shape: + del constants + s_shape = ( + () + if isinstance(self._shift, (int, float, complex)) + else self._shift.shape + ) + if len(s_shape) > len(input_shape): + raise ValueError( + f'Shift parameter has too many dimensions ({len(s_shape)}) to' + f' broadcast with input shape ({len(input_shape)}).' + ) + try: + return np.broadcast_shapes(tuple(input_shape), s_shape) + except ValueError as e: + raise ValueError( + f'Cannot broadcast shape {input_shape} with shift shape {s_shape}' + ) from e + + @types.check_layer + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Adds a learned or fixed shift to the input sequence.""" + return x.apply_values(lambda v: v + self._shift) # --------------------------------------------------------------------------- @@ -312,21 +680,34 @@ def from_config(cls, config): # --------------------------------------------------------------------------- -class MaskInvalid(types.PreservesType, types.StatelessPointwise): +class MaskInvalid( + types.PreservesType, types.StatelessPointwise, spec.MaskInvalid +): """Masks invalid timesteps to zero (or a specified value).""" - def __init__(self, mask_value=None): + @dataclasses.dataclass(frozen=True) + class Config(spec.MaskInvalid.Config): + """Configuration for MaskInvalid layer.""" + + mask_value: Any = None + name: str | None = None + + @override + def make(self) -> 'MaskInvalid': + """Creates the MaskInvalid layer.""" + return MaskInvalid(self) + + def __init__(self, config: Config): super().__init__() - self._mask_value = mask_value + self.config = config @types.check_layer - def layer(self, x, *, constants=None): - return x.mask_invalid(self._mask_value) - - @classmethod - def from_config(cls, config): - mask_value = getattr(config, 'mask_value', None) - return cls(mask_value=mask_value) + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Masks invalid values (NaN, Inf) in the input sequence.""" + return x.mask_invalid(self.config.mask_value) # --------------------------------------------------------------------------- @@ -334,23 +715,37 @@ def from_config(cls, config): # --------------------------------------------------------------------------- -class GatedUnit(types.PreservesType, types.Stateless): +class GatedUnit( + types.PreservesType, + types.Stateless, + spec.GatedUnit[types.Sequence, types.ShapeDType], +): """Computes a generalized Gated Unit, reducing input channels by 2x.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.GatedUnit.Config): + """Configuration for GatedUnit layer.""" + feature_activation: Callable | None = None gate_activation: Callable | None = None name: str | None = None + @override def make(self) -> 'GatedUnit': - return GatedUnit.from_config(self) + return GatedUnit(self) - def __init__(self, feature_activation=None, gate_activation=None): + def __init__(self, config: Config): super().__init__() - self._feature_activation = feature_activation - self._gate_activation = gate_activation + self.config = config + self._feature_activation = config.feature_activation + self._gate_activation = config.gate_activation + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + @override def get_output_shape(self, input_shape, *, constants=None): channels = input_shape[-1] if channels % 2 != 0: @@ -361,7 +756,11 @@ def get_output_shape(self, input_shape, *, constants=None): return tuple(input_shape[:-1]) + (channels // 2,) @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Applies a gated unit to the input sequence.""" feature, gate = mx.split(x.values, 2, axis=-1) if self._feature_activation: feature = self._feature_activation(feature) @@ -369,39 +768,50 @@ def layer(self, x, *, constants=None): gate = self._gate_activation(gate) return Sequence(feature * gate, x.mask) - @classmethod - def from_config(cls, config): - fa = init_mapping.map_activation(config.feature_activation) - ga = init_mapping.map_activation(config.gate_activation) - return cls(feature_activation=fa, gate_activation=ga) - -class GatedLinearUnit(GatedUnit): +class GatedLinearUnit( + GatedUnit, spec.GatedLinearUnit[types.Sequence, types.ShapeDType] +): """Computes a Gated Linear Unit, reducing input channels by 2x.""" - def __init__(self): - super().__init__( - feature_activation=None, - gate_activation=mx.sigmoid, - ) + @dataclasses.dataclass(frozen=True) + class Config(GatedUnit.Config, spec.GatedLinearUnit.Config): + """Configuration for GatedLinearUnit layer.""" - @classmethod - def from_config(cls, config): - return cls() + name: str | None = None + @override + def make(self) -> 'GatedLinearUnit': + """Create GatedLinearUnit layer.""" + return GatedLinearUnit( + GatedUnit.Config( + feature_activation=None, + gate_activation=mx.sigmoid, + name=self.name, + ) + ) -class GatedTanhUnit(GatedUnit): + +class GatedTanhUnit( + GatedUnit, spec.GatedTanhUnit[types.Sequence, types.ShapeDType] +): """Computes a Gated Tanh Unit, reducing input channels by 2x.""" - def __init__(self): - super().__init__( - feature_activation=mx.tanh, - gate_activation=mx.sigmoid, - ) + @dataclasses.dataclass(frozen=True) + class Config(GatedUnit.Config, spec.GatedTanhUnit.Config): + """Configuration for GatedTanhUnit layer.""" - @classmethod - def from_config(cls, config): - return cls() + name: str | None = None + + @override + def make(self) -> 'GatedTanhUnit': + return GatedTanhUnit( + GatedUnit.Config( + feature_activation=mx.tanh, + gate_activation=mx.sigmoid, + name=self.name, + ) + ) # --------------------------------------------------------------------------- @@ -409,21 +819,37 @@ def from_config(cls, config): # --------------------------------------------------------------------------- -class Flatten(types.PreservesType, types.StatelessPointwise): +class Flatten( + types.PreservesType, + types.StatelessPointwise, + spec.Flatten[types.Sequence, types.ShapeDType], +): """Flattens the channel dimensions of the input sequence.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(types.SequenceLayerConfig): + """Configuration for Flatten layer.""" + name: str | None = None + @override def make(self) -> 'Flatten': - return Flatten.from_config(self) + return Flatten(self) + + def __init__(self, config: Config): + super().__init__() + self.config = config + @override def get_output_shape(self, input_shape, *, constants=None): return (math.prod(input_shape),) @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Flattens the channel dimensions of the input sequence.""" batch_size, time = x.values.shape[:2] num_elements = math.prod(x.channel_shape) new_values = mx.reshape(x.values, (batch_size, time, num_elements)) @@ -431,61 +857,78 @@ def layer(self, x, *, constants=None): return MaskedSequence(new_values, x.mask) return Sequence(new_values, x.mask) - @classmethod - def from_config(cls, config): - return cls() - -class Reshape(types.PreservesType, types.Stateless): +class Reshape( + types.PreservesType, + types.Stateless, + spec.Reshape[types.Sequence, types.ShapeDType], +): """Reshapes the channels dimension of the input.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.Reshape.Config): + """Configuration for Reshape layer.""" + output_shape: tuple[int, ...] = () name: str | None = None def __post_init__(self): object.__setattr__(self, 'output_shape', tuple(self.output_shape)) + @override def make(self) -> 'Reshape': - return Reshape.from_config(self) + return Reshape(self) - def __init__(self, output_shape): + def __init__(self, config: Config): super().__init__() - self._output_shape = tuple(output_shape) + self.config = config + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) def _validate(self, input_shape): + """Validates that input and output shapes have the same number of elements.""" in_elems = math.prod(input_shape) - out_elems = math.prod(self._output_shape) + + out_elems = math.prod(self.config.output_shape) if in_elems != out_elems: raise ValueError( - f'Reshape output_shape={self._output_shape} must have' + f'Reshape output_shape={self.config.output_shape} must have' f' the same number of elements as {input_shape=}.' ) + @override def get_output_shape(self, input_shape, *, constants=None): self._validate(input_shape) - return self._output_shape + return self.config.output_shape @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Reshapes the channel dimensions of the input sequence.""" self._validate(x.channel_shape) b, t = x.values.shape[:2] - new_values = mx.reshape(x.values, (b, t) + self._output_shape) + new_values = mx.reshape(x.values, (b, t) + self.config.output_shape) if isinstance(x, MaskedSequence): return MaskedSequence(new_values, x.mask) return Sequence(new_values, x.mask) - @classmethod - def from_config(cls, config): - return cls(output_shape=config.output_shape) - -class ExpandDims(types.PreservesType, types.Stateless): +class ExpandDims( + types.PreservesType, + types.Stateless, + spec.ExpandDims[types.Sequence, types.ShapeDType], +): """Expands channel dimensions of the input sequence.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.ExpandDims.Config): + """Configuration for ExpandDims layer.""" + axis: int | tuple[int, ...] = 0 name: str | None = None @@ -493,18 +936,26 @@ def __post_init__(self): if not isinstance(self.axis, int): object.__setattr__(self, 'axis', tuple(self.axis)) + @override def make(self) -> 'ExpandDims': - return ExpandDims.from_config(self) + return ExpandDims(self) - def __init__(self, axis): + def __init__(self, config: Config): super().__init__() - if isinstance(axis, int): - self._axis = (axis,) - else: - self._axis = tuple(axis) + self.config = config + self._axis: tuple[int, ...] = ( + (config.axis,) if isinstance(config.axis, int) else tuple(config.axis) + ) + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) def _normalize_axes(self, input_shape): + """Normalizes axes to positive indices.""" rank = len(input_shape) + dims = sorted(set(a + rank + 1 if a < 0 else a for a in self._axis)) for d in dims: if d < 0 or d > rank: @@ -513,6 +964,7 @@ def _normalize_axes(self, input_shape): ) return dims + @override def get_output_shape(self, input_shape, *, constants=None): dims = self._normalize_axes(input_shape) out = list(input_shape) @@ -521,47 +973,72 @@ def get_output_shape(self, input_shape, *, constants=None): return tuple(out) @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Expands the dimensions of the input sequence by inserting new axes.""" dims = [2 + d for d in self._normalize_axes(x.channel_shape)] new_values = mx.expand_dims(x.values, axis=dims) if isinstance(x, MaskedSequence): return MaskedSequence(new_values, x.mask) return Sequence(new_values, x.mask) - @classmethod - def from_config(cls, config): - return cls(axis=config.axis) - -class Squeeze(types.PreservesType, types.Stateless): +class Squeeze( + types.PreservesType, + types.Stateless, + spec.Squeeze[types.Sequence, types.ShapeDType], +): """Squeezes singleton channel dimensions of the input.""" - def __init__(self, axis=None): + @dataclasses.dataclass(frozen=True) + class Config(spec.Squeeze.Config): + """Configuration for Squeeze layer.""" + + axis: int | tuple[int, ...] | None = None + name: str | None = None + + @override + def make(self) -> 'Squeeze': + return Squeeze(self) + + def __init__(self, config: Config): super().__init__() - self._axis = axis + self.config = config + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) def _channel_squeeze_axes(self, input_shape): """Return channel-relative axes to squeeze.""" - if self._axis is None: + if self.config.axis is None: # Squeeze all singleton channel dims. return tuple(i for i, n in enumerate(input_shape) if n == 1) # If axis is given, it's in full-tensor coords. Convert to channel. - if isinstance(self._axis, int): - axes = (self._axis,) + if isinstance(self.config.axis, int): + axes = (self.config.axis,) else: - axes = tuple(self._axis) + axes = tuple(self.config.axis) return axes + @override def get_output_shape(self, input_shape, *, constants=None): squeeze_axes = self._channel_squeeze_axes(input_shape) out = [] for i, s in enumerate(input_shape): if i not in squeeze_axes: out.append(s) - return tuple(out) if out else (1,) + return tuple(out) @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Squeezes the dimensions of the input sequence by removing axes of size 1.""" ch_axes = self._channel_squeeze_axes(x.channel_shape) # Convert to full-tensor axes (offset by 2 for batch, time). full_axes = tuple(a + 2 for a in ch_axes) @@ -570,22 +1047,41 @@ def layer(self, x, *, constants=None): return MaskedSequence(new_values, x.mask) return Sequence(new_values, x.mask) - @classmethod - def from_config(cls, config): - return cls(axis=config.axis) - -class Transpose(types.PreservesType, types.Stateless): +class Transpose( + types.PreservesType, + types.Stateless, + spec.Transpose[types.Sequence, types.ShapeDType], +): """Permutes the channel axes of the input.""" - def __init__(self, axes=None): + @dataclasses.dataclass(frozen=True) + class Config(spec.Transpose.Config): + """Configuration for Transpose layer.""" + + axes: tuple[int, ...] | None = None + name: str | None = None + + @override + def make(self) -> 'Transpose': + return Transpose(self) + + def __init__(self, config: Config): super().__init__() - if axes is not None: - axes = tuple(axes) - self._axes = axes + self.config = config + self._axes: tuple[int, ...] | None = ( + tuple(config.axes) if config.axes is not None else None + ) + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) def _resolve_axes(self, input_shape): + """Resolves axes for transpose.""" input_axes = tuple(range(2, 2 + len(input_shape))) + if self._axes is None: return input_axes[::-1] sorted_axes = tuple(sorted(self._axes)) @@ -595,262 +1091,358 @@ def _resolve_axes(self, input_shape): ) return tuple(self._axes) + @override def get_output_shape(self, input_shape, *, constants=None): axes = self._resolve_axes(input_shape) return tuple(input_shape[a - 2] for a in axes) @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Transposes the channel dimensions of the input sequence.""" axes = self._resolve_axes(x.channel_shape) new_values = mx.transpose(x.values, (0, 1) + axes) if isinstance(x, MaskedSequence): return MaskedSequence(new_values, x.mask) return Sequence(new_values, x.mask) - @classmethod - def from_config(cls, config): - return cls(axes=config.axes) - # --------------------------------------------------------------------------- # Encoding # --------------------------------------------------------------------------- -class OneHot(types.Stateless): +class OneHot(types.Stateless, spec.OneHot[types.Sequence, types.ShapeDType]): """Computes one-hot vector of the input.""" - def __init__(self, depth, compute_dtype=mx.float32): + @dataclasses.dataclass(frozen=True) + class Config(spec.OneHot.Config): + """Configuration for OneHot layer.""" + + depth: int + compute_dtype: Any = mx.float32 + name: str | None = None + + @override + def make(self) -> 'OneHot': + return OneHot(self) + + def __init__(self, config: Config): super().__init__() - self._depth = depth - self._compute_dtype = compute_dtype + self.config = config + self._compute_dtype = _to_mx_dtype(config.compute_dtype) + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + + @override def get_output_shape(self, input_shape, *, constants=None): - return tuple(input_shape) + (self._depth,) + return tuple(input_shape) + (self.config.depth,) - def get_output_dtype(self, input_dtype, *, constants=None): + @override + def get_output_dtype(self, input_dtype, *, constants=None) -> mx.Dtype: + assert self._compute_dtype is not None return self._compute_dtype @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Converts integer values to one-hot representations.""" + def one_hot_fn(v): indices = v.astype(mx.int32) - return mx.eye(self._depth, dtype=self._compute_dtype)[indices] + return mx.eye(self.config.depth, dtype=self._compute_dtype)[indices] return x.apply_values(one_hot_fn) - @classmethod - def from_config(cls, config): - from sequence_layers.mlx.init_mapping import _to_mx_dtype - return cls( - depth=config.depth, - compute_dtype=_to_mx_dtype(config.compute_dtype), - ) - - -class Embedding(types.Stateless): +class Embedding( + types.Stateless, spec.Embedding[types.Sequence, types.ShapeDType] +): """Computes embeddings of integer input codes. Backed by mlx.nn.Embedding. """ @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.Embedding.Config): + """Configuration for Embedding layer.""" + num_embeddings: int = 1 dimension: int = 1 compute_dtype: types.DType | None = None param_dtype: types.DType = mx.float32 name: str | None = None + @override def make(self) -> 'Embedding': - return Embedding.from_config(self) + return Embedding(self) - def __init__( - self, - *, - num_embeddings: int, - dimension: int, - param_dtype=mx.float32, - compute_dtype=None, - ): + def __init__(self, config: Config): super().__init__() - self.num_embeddings = num_embeddings - self.dimension = dimension - self._param_dtype = param_dtype - self.compute_dtype = compute_dtype - self._embedding = nn.Embedding(num_embeddings, dimension) + self.config = config + self._param_dtype = _to_mx_dtype(config.param_dtype) + self._compute_dtype = ( + _to_mx_dtype(config.compute_dtype) + if config.compute_dtype is not None + else None + ) + self._embedding = nn.Embedding(config.num_embeddings, config.dimension) + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + + @override def get_output_shape(self, input_shape, *, constants=None): - return tuple(input_shape) + (self.dimension,) + return tuple(input_shape) + (self.config.dimension,) - def get_output_dtype(self, input_dtype, *, constants=None): - if self.compute_dtype is not None: - return self.compute_dtype + @override + def get_output_dtype(self, input_dtype, *, constants=None) -> mx.Dtype: + if self._compute_dtype is not None: + return self._compute_dtype + assert self._param_dtype is not None return self._param_dtype @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Embeds integer values using a learned embedding matrix.""" + def embed_fn(v): result = self._embedding(v.astype(mx.int32)) - if self.compute_dtype is not None: - result = result.astype(self.compute_dtype) + compute_dtype = self._compute_dtype + if compute_dtype is not None: + result = result.astype(compute_dtype) # type: ignore return result return x.apply_values(embed_fn) - @classmethod - def from_config(cls, config): - from sequence_layers.mlx.init_mapping import _to_mx_dtype - - compute_dtype = getattr(config, 'compute_dtype', None) - if compute_dtype is not None: - compute_dtype = _to_mx_dtype(compute_dtype) - return cls( - num_embeddings=config.num_embeddings, - dimension=config.dimension, - param_dtype=_to_mx_dtype(config.param_dtype), - compute_dtype=compute_dtype, - ) - # --------------------------------------------------------------------------- # Regularization # --------------------------------------------------------------------------- -class Dropout(types.PreservesType, types.StatelessPointwise): +class Dropout( + types.PreservesType, + types.StatelessPointwise, + spec.Dropout[types.Sequence, types.ShapeDType], +): """Dropout layer (pass-through during inference).""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.Dropout.Config): + """Configuration for Dropout layer.""" + rate: float = 0.0 broadcast_dims: tuple[int, ...] = () name: str | None = None + @override def make(self) -> 'Dropout': - return Dropout.from_config(self) + """Creates the Dropout layer.""" + return Dropout(self) - def __init__(self, rate=0.0): + def __init__(self, config: Config): super().__init__() - self._rate = rate + self.config = config @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Applies dropout to the input sequence.""" + if training: + raise NotImplementedError('Dropout training is not implemented in MLX.') # Inference-only: dropout is a no-op. return x - @classmethod - def from_config(cls, config): - return cls(rate=config.rate) - # --------------------------------------------------------------------------- # Sampling # --------------------------------------------------------------------------- -class Downsample1D(types.PreservesType, types.Stateless): +class Downsample1D( + types.PreservesType, + types.Stateless, + spec.Downsample1D[types.Sequence, types.ShapeDType], +): """A 1D downsampling layer.""" - def __init__(self, rate): + @dataclasses.dataclass(frozen=True) + class Config(spec.Downsample1D.Config): + """Configuration for Downsample1D layer.""" + + rate: int + name: str | None = None + + @override + def make(self) -> 'Downsample1D': + return Downsample1D(self) + + def __init__(self, config: Config): super().__init__() - self._rate = rate + self.config = config @property + @override def block_size(self): - return self._rate + return self.config.rate + + @property + @override + def output_ratio(self): + return Fraction(1, self.config.rate) + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + @override def get_output_shape(self, input_shape, *, constants=None): return tuple(input_shape) @types.check_layer - def layer(self, x, *, constants=None): - new_values = x.values[:, :: self._rate] - new_mask = x.mask[:, :: self._rate] + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Downsamples the input sequence along the time axis.""" + new_values = x.values[:, :: self.config.rate] + new_mask = x.mask[:, :: self.config.rate] if isinstance(x, MaskedSequence): return MaskedSequence(new_values, new_mask) return Sequence(new_values, new_mask) - @classmethod - def from_config(cls, config): - return cls(rate=config.rate) - -class Upsample1D(types.PreservesType, types.Stateless): +class Upsample1D( + types.PreservesType, + types.Stateless, + spec.Upsample1D[types.Sequence, types.ShapeDType], +): """A 1D upsampling layer.""" - def __init__(self, rate): + @dataclasses.dataclass(frozen=True) + class Config(spec.Upsample1D.Config): + """Configuration for Upsample1D layer.""" + + rate: int + name: str | None = None + + @override + def make(self) -> 'Upsample1D': + return Upsample1D(self) + + def __init__(self, config: Config): super().__init__() - self._rate = rate + self.config = config + + @property + @override + def output_ratio(self): + return Fraction(self.config.rate, 1) + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + @override def get_output_shape(self, input_shape, *, constants=None): return tuple(input_shape) @types.check_layer - def layer(self, x, *, constants=None): + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Upsamples the input sequence along the time axis.""" # Repeat each timestep `rate` times along the time axis. b, t = x.values.shape[:2] channel_shape = x.values.shape[2:] # [b, t, 1, ...] -> [b, t, rate, ...] -> [b, t*rate, ...] expanded = mx.expand_dims(x.values, axis=2) - tiled = mx.repeat(expanded, self._rate, axis=2) - new_values = mx.reshape(tiled, (b, t * self._rate) + channel_shape) + tiled = mx.repeat(expanded, self.config.rate, axis=2) + new_values = mx.reshape(tiled, (b, t * self.config.rate) + channel_shape) # Same for mask: [b, t] -> [b, t*rate] - new_mask = mx.repeat(mx.expand_dims(x.mask, axis=2), self._rate, axis=2) - new_mask = mx.reshape(new_mask, (b, t * self._rate)) + new_mask = mx.repeat( + mx.expand_dims(x.mask, axis=2), self.config.rate, axis=2 + ) + new_mask = mx.reshape(new_mask, (b, t * self.config.rate)) if isinstance(x, MaskedSequence): return MaskedSequence(new_values, new_mask) return Sequence(new_values, new_mask) - @classmethod - def from_config(cls, config): - return cls(rate=config.rate) - # --------------------------------------------------------------------------- # CheckpointName (identity for inference) # --------------------------------------------------------------------------- -class CheckpointName(types.PreservesType, types.StatelessPointwiseFunctor): +class CheckpointName( + types.PreservesType, + types.StatelessPointwiseFunctor, + spec.CheckpointName[types.Sequence, types.ShapeDType], +): """Identity pass-through (checkpoint naming is JAX-only).""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.CheckpointName.Config): + """Configuration for CheckpointName layer.""" + checkpoint_name: str = '' name: str | None = None + @override def make(self) -> 'CheckpointName': - return CheckpointName.from_config(self) + """Creates the CheckpointName layer.""" + return CheckpointName(self) - def __init__(self, checkpoint_name=''): + def __init__(self, config: Config): super().__init__() - self._checkpoint_name = checkpoint_name + self.config = config + + @override + def get_accumulated_input_latency(self, input_latency: int) -> int: + return super().get_accumulated_input_latency(input_latency) @property + @override def mask_required(self): return False - def fn(self, values, mask): + @override + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: + """Identity function for CheckpointName.""" return values, mask - @classmethod - def from_config(cls, config): - return cls(checkpoint_name=config.checkpoint_name) - # --------------------------------------------------------------------------- # Lambda # --------------------------------------------------------------------------- -class Lambda(types.Stateless): +class Lambda(types.Stateless, spec.Lambda[types.Sequence, types.ShapeDType]): """A SequenceLayer that wraps a Python callable.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): - fn: Callable = None + class Config(spec.Lambda.Config): + """Configuration for Lambda layer.""" + + fn: Callable sequence_input: bool = False mask_required: bool = True # Accepted for JAX compatibility but ignored by MLX Lambda. @@ -858,75 +1450,76 @@ class Config(_SequenceLayerConfig): expected_output_spec: object = None name: str | None = None + @override def make(self) -> 'Lambda': - return Lambda.from_config(self) + return Lambda(self) - def __init__(self, fn, *, sequence_input=False, mask_required=True, - expected_output_spec=None): + def __init__(self, config: Config): super().__init__() - self._fn = fn - self._sequence_input = sequence_input - self._mask_required = mask_required - self._expected_output_spec = expected_output_spec + self.config = config self._cached_output_spec = None + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + def _probe_output(self, input_shape, input_dtype): """Probe the function with a dummy to infer output shape/dtype.""" - if self._expected_output_spec is not None: - return self._expected_output_spec + if self.config.expected_output_spec is not None: + return self.config.expected_output_spec if self._cached_output_spec is not None: return self._cached_output_spec try: dummy_values = mx.zeros((1, 1) + tuple(input_shape), dtype=input_dtype) dummy_mask = mx.ones((1, 1), dtype=mx.bool_) - if self._sequence_input: - result = self._fn(Sequence(dummy_values, dummy_mask)) + assert self.config.fn is not None + if self.config.sequence_input: + result = self.config.fn(Sequence(dummy_values, dummy_mask)) out_shape = result.values.shape[2:] out_dtype = result.values.dtype else: - out_values = self._fn(dummy_values) + out_values = self.config.fn(dummy_values) out_shape = out_values.shape[2:] out_dtype = out_values.dtype - self._cached_output_spec = bt.ShapeDType(out_shape, out_dtype) + self._cached_output_spec = types.ShapeDType(out_shape, out_dtype) return self._cached_output_spec - except Exception: + except Exception: # pylint: disable=broad-exception-caught return None + @override def get_output_shape(self, input_shape, *, constants=None): - spec = self._probe_output(input_shape, mx.float32) - if spec is not None: - return tuple(spec.shape) + out_spec = self._probe_output(input_shape, mx.float32) + if out_spec is not None: + return tuple(out_spec.shape) return tuple(input_shape) + @override def get_output_dtype(self, input_dtype, *, constants=None): - spec = self._probe_output((1,), input_dtype) - if spec is not None: - return spec.dtype + out_spec = self._probe_output((1,), input_dtype) + if out_spec is not None: + return out_spec.dtype return input_dtype - def layer(self, x, *, constants=None): - if self._sequence_input: - result = self._fn(x) + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Applies a custom Python callable to the input sequence.""" + assert self.config.fn is not None + if self.config.sequence_input: + result = self.config.fn(x) if not isinstance(result, (Sequence, MaskedSequence)): raise ValueError( 'Lambda with sequence_input=True must return a Sequence, ' f'got {type(result)}' ) return result - else: - new_values = self._fn(x.values) - if self._mask_required or not isinstance(x, MaskedSequence): - return Sequence(new_values, x.mask) - return MaskedSequence(new_values, x.mask) - @classmethod - def from_config(cls, config): - return cls( - fn=config.fn, - sequence_input=config.sequence_input, - mask_required=config.mask_required, - expected_output_spec=getattr(config, 'expected_output_spec', None), - ) + new_values = self.config.fn(x.values) + if self.config.mask_required or not isinstance(x, MaskedSequence): + return Sequence(new_values, x.mask) + return MaskedSequence(new_values, x.mask) # --------------------------------------------------------------------------- @@ -934,37 +1527,88 @@ def from_config(cls, config): # --------------------------------------------------------------------------- -class Logging(types.PreservesType, types.StatelessPointwise): +class Logging( + types.PreservesType, + types.StatelessPointwise, + spec.Logging[types.Sequence, types.ShapeDType], +): """Logs input info and returns the input unchanged.""" @dataclasses.dataclass(frozen=True) - class Config(_SequenceLayerConfig): + class Config(spec.Logging.Config): + """Configuration for Logging layer.""" + prefix: str = '' dump_tensors: bool = False name: str | None = None + @override def make(self) -> 'Logging': - return Logging.from_config(self) + """Creates the Logging layer.""" + return Logging(self) - def __init__(self, prefix='', dump_tensors=False): + def __init__(self, config: Config): super().__init__() - self._prefix = prefix - self._dump_tensors = dump_tensors + self.config = config + + @override + def get_initial_state( + self, + batch_size: int, + input_spec: types.ChannelSpec, + *, + training: bool, + constants: types.Constants | None = None, + ) -> types.State: + if self.config.dump_tensors: + logging.info( + f'{self.config.prefix} get_initial_state(): batch_size={batch_size}, ' + f'input_spec={input_spec}, training={training}, constants={constants}' + ) + else: + logging.info( + f'{self.config.prefix} get_initial_state(): batch_size={batch_size}, ' + f'input_spec={input_spec}, training={training}' + ) + return super().get_initial_state( + batch_size, input_spec, training=training, constants=constants + ) + + @override + def step( + self, + x: types.Sequence, + state: types.State, + *, + training: bool, + constants: types.Constants | None = None, + ) -> tuple[types.Sequence, types.State]: + if self.config.dump_tensors: + logging.info( + f'{self.config.prefix} step(): x={x.values}, state={state}, ' + f'training={training}, constants={constants}' + ) + else: + logging.info( + f'{self.config.prefix} step(): x.shape={x.shape}, x.dtype={x.dtype}, ' + f'state={state}, training={training}' + ) + return super().step(x, state, training=training, constants=constants) @types.check_layer - def layer(self, x, *, constants=None): - if self._dump_tensors: - print(f'{self._prefix} layer(): x={x.values}') + @override + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + """Logs the input sequence values for debugging.""" + if self.config.dump_tensors: + logging.info( + f'{self.config.prefix} layer(): x={x.values}, training={training},' + f' constants={constants}' + ) else: - print( - f'{self._prefix} layer(): x.shape={x.shape}, ' - f'x.dtype={x.dtype}' + logging.info( + f'{self.config.prefix} layer(): x.shape={x.shape}, x.dtype={x.dtype},' + f' training={training}' ) return x - - @classmethod - def from_config(cls, config): - return cls( - prefix=config.prefix, - dump_tensors=config.dump_tensors, - ) diff --git a/sequence_layers/mlx/simple_test.py b/sequence_layers/mlx/simple_test.py index 08b19da..90bd290 100644 --- a/sequence_layers/mlx/simple_test.py +++ b/sequence_layers/mlx/simple_test.py @@ -1,507 +1,195 @@ """Tests for simple MLX sequence layers.""" -import mlx.core as mx -import numpy as np +from typing import override + from absl.testing import absltest -from absl.testing import parameterized -from sequence_layers.mlx import basic_types as bt +import numpy as np + from sequence_layers.mlx import simple from sequence_layers.mlx import test_utils +from sequence_layers.specs import simple_behaviors as spec -class IdentityTest(parameterized.TestCase): +class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): + pass - def test_layer(self): - layer = simple.Identity() - test_utils.verify_contract(self, layer, (4,)) + +class IdentityTest(test_utils.SequenceLayerTest, spec.IdentityTest): def test_preserves_values(self): - layer = simple.Identity() - x = test_utils.random_sequence(2, 3, 4) - y = layer.layer(x) + layer = simple.Identity.Config().make() + x = self.random_sequence(2, 3, 4) + y = layer.layer(x, training=False) np.testing.assert_array_equal(y.values, x.values) np.testing.assert_array_equal(y.mask, x.mask) -class ReluTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Relu() - test_utils.verify_contract(self, layer, (4,)) - - def test_negative_zeroed(self): - layer = simple.Relu() - values = mx.array([[-1.0, 0.5, -0.3, 2.0]]).reshape(1, 1, 4) - mask = mx.ones((1, 1), dtype=mx.bool_) - x = bt.MaskedSequence(values, mask) - y = layer.layer(x) - expected = mx.array([[[0.0, 0.5, 0.0, 2.0]]]) - np.testing.assert_allclose(y.values, expected, atol=1e-6) - - -class GeluTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Gelu() - test_utils.verify_contract(self, layer, (4,)) - - -class SwishTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Swish() - test_utils.verify_contract(self, layer, (4,)) - - -class TanhTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Tanh() - test_utils.verify_contract(self, layer, (4,)) - - def test_values(self): - layer = simple.Tanh() - values = mx.array([[[0.0, 1.0, -1.0, 100.0]]]) - mask = mx.ones((1, 1), dtype=mx.bool_) - x = bt.MaskedSequence(values, mask) - y = layer.layer(x) - np.testing.assert_allclose( - y.values, np.tanh([[[0.0, 1.0, -1.0, 100.0]]]), atol=1e-5 - ) - - -class SigmoidTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Sigmoid() - test_utils.verify_contract(self, layer, (4,)) - - -class LeakyReluTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.LeakyRelu(negative_slope=0.2) - test_utils.verify_contract(self, layer, (4,)) - - def test_negative_slope(self): - layer = simple.LeakyRelu(negative_slope=0.1) - values = mx.array([[[-2.0, 0.5, -1.0, 3.0]]]) - mask = mx.ones((1, 1), dtype=mx.bool_) - x = bt.MaskedSequence(values, mask) - y = layer.layer(x) - expected = mx.array([[[-0.2, 0.5, -0.1, 3.0]]]) - np.testing.assert_allclose(y.values, expected, atol=1e-6) - - -class EluTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Elu() - test_utils.verify_contract(self, layer, (4,)) - - -class SoftmaxTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Softmax() - test_utils.verify_contract(self, layer, (4,)) - - def test_sums_to_one(self): - layer = simple.Softmax(axis=-1) - values = mx.array([[[1.0, 2.0, 3.0, 4.0]]]) - mask = mx.ones((1, 1), dtype=mx.bool_) - x = bt.MaskedSequence(values, mask) - y = layer.layer(x) - np.testing.assert_allclose(float(mx.sum(y.values)), 1.0, atol=1e-5) - - -class SoftplusTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Softplus() - test_utils.verify_contract(self, layer, (4,)) - +class PointwiseMathTest(test_utils.SequenceLayerTest, spec.PointwiseMathTest): -class CastTest(parameterized.TestCase): + @override + def make_layer(self, layer_name): + layer_cls = getattr(self.sl, layer_name) + return layer_cls(layer_cls.Config()) - def test_layer(self): - layer = simple.Cast(dtype=mx.float16) - test_utils.verify_contract(self, layer, (4,), atol=1e-3, rtol=1e-3) - def test_cast(self): - layer = simple.Cast(dtype=mx.float16) - x = test_utils.random_sequence(1, 3, 4) - y = layer.layer(x) - self.assertEqual(y.dtype, mx.float16) +class CastTest(test_utils.SequenceLayerTest, spec.CastTest): + pass -class ScaleTest(parameterized.TestCase): +class ScaleTest(test_utils.SequenceLayerTest, spec.ScaleTest): + pass - def test_layer(self): - layer = simple.Scale(scale=2.0) - test_utils.verify_contract(self, layer, (4,)) - def test_scalar(self): - layer = simple.Scale(scale=2.0) - values = mx.array([[[1.0, 2.0, 3.0]]]) - mask = mx.ones((1, 1), dtype=mx.bool_) - x = bt.MaskedSequence(values, mask) - y = layer.layer(x) - expected = mx.array([[[2.0, 4.0, 6.0]]]) - np.testing.assert_allclose(y.values, expected, atol=1e-6) +class AddTest(test_utils.SequenceLayerTest, spec.AddTest): + pass -class AddTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Add(shift=1.0) - test_utils.verify_contract(self, layer, (4,)) +class MaskInvalidTest(test_utils.SequenceLayerTest, spec.MaskInvalidTest): + pass - def test_scalar(self): - layer = simple.Add(shift=10.0) - values = mx.array([[[1.0, 2.0, 3.0]]]) - mask = mx.ones((1, 1), dtype=mx.bool_) - x = bt.MaskedSequence(values, mask) - y = layer.layer(x) - expected = mx.array([[[11.0, 12.0, 13.0]]]) - np.testing.assert_allclose(y.values, expected, atol=1e-6) - - -class MaskInvalidTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.MaskInvalid() - test_utils.verify_contract(self, layer, (4,)) - def test_masks_to_zero(self): - layer = simple.MaskInvalid() - values = mx.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]) - mask = mx.array([[True, False, True]]) - x = bt.Sequence(values, mask) - y = layer.layer(x) - expected = mx.array([[[1.0, 2.0], [0.0, 0.0], [5.0, 6.0]]]) - np.testing.assert_allclose(y.values, expected, atol=1e-6) +class GatedUnitTest(test_utils.SequenceLayerTest, spec.GatedUnitTest): + pass -class GatedUnitTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.GatedUnit() - test_utils.verify_contract(self, layer, (8,)) - - def test_with_activations(self): - import mlx.nn as nn - - layer = simple.GatedUnit( - feature_activation=nn.relu, gate_activation=nn.sigmoid - ) - test_utils.verify_contract(self, layer, (8,)) - - -class GatedLinearUnitTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.GatedLinearUnit() - test_utils.verify_contract(self, layer, (8,)) - - def test_halves_channels(self): - layer = simple.GatedLinearUnit() - self.assertEqual(layer.get_output_shape((8,)), (4,)) - - -class GatedTanhUnitTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.GatedTanhUnit() - test_utils.verify_contract(self, layer, (8,)) - - -class FlattenTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Flatten() - test_utils.verify_contract(self, layer, (2, 3, 4)) +class FlattenTest(test_utils.SequenceLayerTest, spec.FlattenTest): + pass - def test_flatten(self): - layer = simple.Flatten() - self.assertEqual(layer.get_output_shape((2, 3, 4)), (24,)) - -class ReshapeTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Reshape(output_shape=(2, 6)) - test_utils.verify_contract(self, layer, (12,)) - - def test_reshape(self): - layer = simple.Reshape(output_shape=(2, 6)) - x = test_utils.random_sequence(1, 3, 12) - y = layer.layer(x) - self.assertEqual(y.channel_shape, (2, 6)) +class ReshapeTest(test_utils.SequenceLayerTest, spec.ReshapeTest): def test_mismatch_raises(self): - layer = simple.Reshape(output_shape=(5,)) + layer = simple.Reshape.Config(output_shape=(5,)).make() + with self.assertRaises(ValueError): layer.get_output_shape((12,)) -class ExpandDimsTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.ExpandDims(axis=-1) - test_utils.verify_contract(self, layer, (4,)) - - def test_expand(self): - layer = simple.ExpandDims(axis=0) - self.assertEqual(layer.get_output_shape((4, 8)), (1, 4, 8)) - - def test_layer_values(self): - layer = simple.ExpandDims(axis=-1) - x = test_utils.random_sequence(1, 3, 4) - y = layer.layer(x) - self.assertEqual(y.channel_shape, (4, 1)) +class ExpandDimsTest(test_utils.SequenceLayerTest, spec.ExpandDimsTest): + pass -class SqueezeTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Squeeze() - test_utils.verify_contract(self, layer, (4, 1)) - - def test_squeeze(self): - layer = simple.Squeeze() - x = bt.MaskedSequence( - mx.ones((1, 3, 1, 4, 1)), - mx.ones((1, 3), dtype=mx.bool_), - ) - y = layer.layer(x) - self.assertEqual(y.channel_shape, (4,)) +class SqueezeTest(test_utils.SequenceLayerTest, spec.SqueezeTest): + pass -class TransposeTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Transpose() - test_utils.verify_contract(self, layer, (2, 3, 4)) +class TransposeTest(test_utils.SequenceLayerTest, spec.TransposeTest): def test_reverse(self): - layer = simple.Transpose() + layer = simple.Transpose.Config().make() self.assertEqual(layer.get_output_shape((2, 3, 4)), (4, 3, 2)) def test_explicit(self): - layer = simple.Transpose(axes=(3, 2, 4)) - self.assertEqual(layer.get_output_shape((5, 6, 7)), (6, 5, 7)) - - -class OneHotTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.OneHot(depth=5) - x = bt.MaskedSequence( - mx.array([[0, 2, 4]]), - mx.ones((1, 3), dtype=mx.bool_), - ) - y = layer.layer(x) - self.assertEqual(y.shape, (1, 3, 5)) - # Check that index 0 -> [1,0,0,0,0] - np.testing.assert_allclose(np.array(y.values[0, 0]), [1, 0, 0, 0, 0]) - - -class EmbeddingTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Embedding(num_embeddings=10, dimension=8) - x = bt.MaskedSequence( - mx.array([[1, 3, 5]]), - mx.ones((1, 3), dtype=mx.bool_), - ) - y = layer.layer(x) - self.assertEqual(y.shape, (1, 3, 8)) - - def test_output_shape(self): - layer = simple.Embedding(num_embeddings=10, dimension=8) - self.assertEqual(layer.get_output_shape(()), (8,)) - self.assertEqual(layer.get_output_shape((3,)), (3, 8)) + layer = simple.Transpose.Config(axes=(3, 2, 4)).make() + self.assertEqual(layer.get_output_shape((5, 6, 7)), (6, 5, 7)) -class DropoutTest(parameterized.TestCase): - def test_layer(self): - layer = simple.Dropout(rate=0.5) - test_utils.verify_contract(self, layer, (4,)) +class OneHotTest(test_utils.SequenceLayerTest, spec.OneHotTest): + pass + + +class EmbeddingTest(test_utils.SequenceLayerTest, spec.EmbeddingTest): + pass + + +class DropoutTest(test_utils.SequenceLayerTest, spec.DropoutTest): + pass + + +class Downsample1DTest(test_utils.SequenceLayerTest, spec.Downsample1DTest): + pass + + +class Upsample1DTest(test_utils.SequenceLayerTest, spec.Upsample1DTest): + pass + + +# class BackendDispatchTest(parameterized.TestCase): +# """Test config.make(backend='mlx') for simple layers.""" +# +# def test_identity(self): +# import sequence_layers.mlx # Register backends. +# from sequence_layers.jax import simple as jax_simple +# +# config = jax_simple.Identity.Config() +# mlx_layer = config.make(backend='mlx') +# self.assertIsInstance(mlx_layer, simple.Identity) +# +# def test_relu(self): +# import sequence_layers.mlx +# from sequence_layers.jax import simple as jax_simple +# +# config = jax_simple.Relu.Config() +# mlx_layer = config.make(backend='mlx') +# self.assertIsInstance(mlx_layer, simple.Relu) +# +# def test_tanh(self): +# import sequence_layers.mlx +# from sequence_layers.jax import simple as jax_simple +# +# config = jax_simple.Tanh.Config() +# mlx_layer = config.make(backend='mlx') +# self.assertIsInstance(mlx_layer, simple.Tanh) +# +# def test_gated_linear_unit(self): +# import sequence_layers.mlx +# from sequence_layers.jax import simple as jax_simple +# +# config = jax_simple.GatedLinearUnit.Config() +# mlx_layer = config.make(backend='mlx') +# self.assertIsInstance(mlx_layer, simple.GatedLinearUnit) +# +# def test_reshape(self): +# import sequence_layers.mlx +# from sequence_layers.jax import simple as jax_simple +# +# config = jax_simple.Reshape.Config(output_shape=(2, 3)) +# mlx_layer = config.make(backend='mlx') +# self.assertIsInstance(mlx_layer, simple.Reshape) +# +# def test_downsample(self): +# import sequence_layers.mlx +# from sequence_layers.jax import simple as jax_simple +# +# config = jax_simple.Downsample1D.Config(rate=2) +# mlx_layer = config.make(backend='mlx') +# self.assertIsInstance(mlx_layer, simple.Downsample1D) + + +class CheckpointNameTest(test_utils.SequenceLayerTest, spec.CheckpointNameTest): + + def test_layer(self): + layer = simple.CheckpointName.Config(checkpoint_name='test').make() + + x = self.random_sequence(2, 3, 4) + self.verify_contract(layer, x) def test_passthrough(self): - layer = simple.Dropout(rate=0.5) - x = test_utils.random_sequence(1, 3, 4) - y = layer.layer(x) - # Inference-only: should be identity. - np.testing.assert_array_equal(y.values, x.values) - - -class Downsample1DTest(parameterized.TestCase): - - def test_verify_contract(self): - layer = simple.Downsample1D(rate=2) - test_utils.verify_contract(self, layer, (4,)) - - def test_layer(self): - layer = simple.Downsample1D(rate=2) - x = test_utils.random_sequence(1, 6, 4) - y = layer.layer(x) - self.assertEqual(y.shape, (1, 3, 4)) - - def test_values(self): - layer = simple.Downsample1D(rate=3) - values = mx.arange(12).reshape(1, 6, 2).astype(mx.float32) - mask = mx.ones((1, 6), dtype=mx.bool_) - x = bt.MaskedSequence(values, mask) - y = layer.layer(x) - # Should keep timesteps 0, 3. - np.testing.assert_array_equal(y.values, values[:, ::3]) - - -class Upsample1DTest(parameterized.TestCase): + layer = simple.CheckpointName.Config(checkpoint_name='test').make() - def test_verify_contract(self): - layer = simple.Upsample1D(rate=3) - test_utils.verify_contract(self, layer, (4,)) - - def test_layer(self): - layer = simple.Upsample1D(rate=3) - x = test_utils.random_sequence(1, 4, 2) - y = layer.layer(x) - self.assertEqual(y.shape, (1, 12, 2)) - - def test_values(self): - layer = simple.Upsample1D(rate=2) - values = mx.array([[[1.0, 2.0], [3.0, 4.0]]]) - mask = mx.ones((1, 2), dtype=mx.bool_) - x = bt.MaskedSequence(values, mask) - y = layer.layer(x) - expected = mx.array([[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]]) - np.testing.assert_allclose(y.values, expected) - self.assertEqual(y.mask.shape, (1, 4)) - - -class BackendDispatchTest(parameterized.TestCase): - """Test config.make(backend='mlx') for simple layers.""" - - def test_identity(self): - import sequence_layers.mlx # Register backends. - from sequence_layers.jax import simple as jax_simple - - config = jax_simple.Identity.Config() - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.Identity) - - def test_relu(self): - import sequence_layers.mlx - from sequence_layers.jax import simple as jax_simple - - config = jax_simple.Relu.Config() - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.Relu) - - def test_tanh(self): - import sequence_layers.mlx - from sequence_layers.jax import simple as jax_simple - - config = jax_simple.Tanh.Config() - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.Tanh) - - def test_gated_linear_unit(self): - import sequence_layers.mlx - from sequence_layers.jax import simple as jax_simple - - config = jax_simple.GatedLinearUnit.Config() - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.GatedLinearUnit) - - def test_reshape(self): - import sequence_layers.mlx - from sequence_layers.jax import simple as jax_simple - - config = jax_simple.Reshape.Config(output_shape=(2, 3)) - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.Reshape) - - def test_downsample(self): - import sequence_layers.mlx - from sequence_layers.jax import simple as jax_simple - - config = jax_simple.Downsample1D.Config(rate=2) - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.Downsample1D) - - -class CheckpointNameTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.CheckpointName(checkpoint_name='test') - test_utils.verify_contract(self, layer, (4,)) - - def test_passthrough(self): - layer = simple.CheckpointName(checkpoint_name='test') - x = test_utils.random_sequence(1, 3, 4) - y = layer.layer(x) + x = self.random_sequence(1, 3, 4) + y = layer.layer(x, training=False) np.testing.assert_array_equal(y.values, x.values) np.testing.assert_array_equal(y.mask, x.mask) - def test_from_config(self): - import sequence_layers.mlx - from sequence_layers.jax import simple as jax_simple - - config = jax_simple.CheckpointName.Config(checkpoint_name='test') - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.CheckpointName) - - -class LambdaTest(parameterized.TestCase): + # def test_from_config(self): - def test_values_fn(self): - layer = simple.Lambda(fn=lambda v: v * 2.0) - x = test_utils.random_sequence(1, 3, 4) - y = layer.layer(x) - np.testing.assert_allclose(y.values, x.values * 2.0, atol=1e-6) - def test_sequence_fn(self): - def double_seq(s): - return bt.Sequence(s.values * 2.0, s.mask) +# import sequence_layers.mlx +# from sequence_layers.jax import simple as jax_simple +# +# config = jax_simple.CheckpointName.Config(checkpoint_name='test') +# mlx_layer = config.make(backend='mlx') +# self.assertIsInstance(mlx_layer, simple.CheckpointName) - layer = simple.Lambda(fn=double_seq, sequence_input=True) - x = test_utils.random_sequence(1, 3, 4) - y = layer.layer(x) - np.testing.assert_allclose(y.values, x.values * 2.0, atol=1e-6) - def test_from_config(self): - import sequence_layers.mlx - from sequence_layers.jax import simple as jax_simple - - config = jax_simple.Lambda.Config(fn=lambda v: v) - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.Lambda) - - -class LoggingTest(parameterized.TestCase): - - def test_layer(self): - layer = simple.Logging(prefix='test') - test_utils.verify_contract(self, layer, (4,)) - - def test_passthrough(self): - layer = simple.Logging() - x = test_utils.random_sequence(1, 3, 4) - y = layer.layer(x) - np.testing.assert_array_equal(y.values, x.values) +class LambdaTest(test_utils.SequenceLayerTest, spec.LambdaTest): + """Test behavior of Lambda layer.""" - def test_from_config(self): - import sequence_layers.mlx - from sequence_layers.jax import simple as jax_simple - config = jax_simple.Logging.Config(prefix='test') - mlx_layer = config.make(backend='mlx') - self.assertIsInstance(mlx_layer, simple.Logging) +class LoggingTest(test_utils.SequenceLayerTest, spec.LoggingTest): + """Test behavior of Logging layer.""" if __name__ == '__main__': diff --git a/sequence_layers/mlx/test_utils.py b/sequence_layers/mlx/test_utils.py index 3af90ca..f031ce7 100644 --- a/sequence_layers/mlx/test_utils.py +++ b/sequence_layers/mlx/test_utils.py @@ -10,7 +10,6 @@ from sequence_layers import specs from sequence_layers.mlx import types -import sequence_layers.mlx as sl from sequence_layers.specs import test_utils as spec Sequence = types.Sequence @@ -84,7 +83,9 @@ def _mask_and_pad_to_max_length( class SequenceLayerTest(spec.SequenceLayerTest): """Base class for MLX SequenceLayer tests.""" - sl = sl # pyrefly: ignore[bad-assignment] # module-as-protocol + import sequence_layers.mlx as sl_module # pylint: disable=import-outside-toplevel + + sl = sl_module # pyrefly: ignore[bad-assignment] # module-as-protocol @override def setUp(self): @@ -93,6 +94,15 @@ def setUp(self): # MLX doesn't have a global seed, but we can set numpy seed. np.random.seed(123456789) + @override + def get_variables(self, layer: Any) -> dict[str, Any]: + + return layer.parameters() + + @override + def init_layer(self, layer, x, bind_only=False): + return layer + @override def random_sequence( self, @@ -111,9 +121,33 @@ def random_sequence( time = dims[1] shape = dims[2:] - values_np = np.random.normal(size=(batch_size, time) + shape).astype( - np.float32 - ) + if dtype is not None: + if dtype == np.float32: + dtype = mx.float32 + elif dtype == np.float16: + dtype = mx.float16 + elif dtype == np.int32: + dtype = mx.int32 + elif dtype == np.bool_: + dtype = mx.bool_ + + if dtype is not None and dtype in ( + mx.int32, + mx.int16, + mx.int8, + mx.uint32, + mx.uint16, + mx.uint8, + ): + values_np = np.random.randint( + low if low is not None else 0, + high if high is not None else 10, + size=(batch_size, time) + shape, + ) + else: + values_np = np.random.normal(size=(batch_size, time) + shape).astype( + np.float32 + ) values = mx.array(values_np, dtype=dtype or mx.float32) mask_np = np.ones((batch_size, time), dtype=bool) @@ -121,6 +155,16 @@ def random_sequence( return types.Sequence(values, mask) + @override + def assertEqual(self, first, second, msg=None): + """Override to handle MLX vs NumPy dtypes.""" + if isinstance(first, mx.Dtype) and isinstance(second, (type, np.dtype)): + first_str = str(first).rsplit('.', maxsplit=1)[-1] + second_str = np.dtype(second).name + if first_str == second_str: + return + super().assertEqual(first, second, msg) + @override def assertAllEqual(self, x, y): """Asserts that two arrays are equal.""" @@ -164,7 +208,7 @@ def _step_by_step( outputs_masks = [] for t in range(0, time, block_size): - x_block = sl.Sequence( + x_block = types.Sequence( x.values[:, t : t + block_size], x.mask[:, t : t + block_size], ) @@ -184,7 +228,7 @@ def _step_by_step( y_values = mx.concatenate(outputs_values, axis=1) y_mask = mx.concatenate(outputs_masks, axis=1) - return sl.Sequence(y_values, y_mask), state + return types.Sequence(y_values, y_mask), state @override # pyrefly: ignore[bad-override] @@ -246,7 +290,9 @@ def assertSequencesClose(self, x: Any, y: Any, **kwargs) -> None: np.testing.assert_array_equal(mask_x, mask_y) -class ModuleSpecTest(SequenceLayerTest, spec.ModuleSpecTest): +class ModuleSpecTest( + SequenceLayerTest, spec.ModuleSpecTest +): # pyrefly: ignore[invalid-inheritance] @override def module_spec_pairs(self, backend_sl: specs.ModuleSpec): diff --git a/sequence_layers/mlx/types.py b/sequence_layers/mlx/types.py index 2c82766..7e75558 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -12,6 +12,7 @@ import jaxtyping as jt from mlx import nn import mlx.core as mx +import numpy as np from sequence_layers.specs import types as spec @@ -39,6 +40,32 @@ InputT = TypeVar('InputT', bound='Sequence') OutputT = TypeVar('OutputT', bound='Sequence') + +def _to_tuple(x: complex | list[Any]) -> complex | tuple[Any, ...]: + """Replaces lists in a pytree of complex with tuples.""" + if isinstance(x, list): + return tuple(_to_tuple(i) for i in x) + return x + + +@dataclasses.dataclass(frozen=True) +class HashableArray: + """Hashable multidimensional array of tuples.""" + + data: complex | tuple[Any, ...] + dtype: Any + + @classmethod + def from_array(cls, x: Any) -> 'HashableArray': + """Creates a HashableArray from a numpy-like array.""" + x = np.asarray(x) + return HashableArray(_to_tuple(x.tolist()), x.dtype) + + def to_array(self) -> Any: + """Converts HashableArray back to a numpy array.""" + return np.asarray(self.data, dtype=self.dtype) + + __all__ = ( # go/keep-sorted start 'ChannelSpec', @@ -47,6 +74,7 @@ 'Emits', 'Emitting', 'ExpandedMaskT', + 'HashableArray', 'LengthsT', 'MASK_DTYPE', 'MaskT', @@ -169,7 +197,10 @@ def from_values(cls, values: ValuesT) -> 'MaskedSequence': """Returns a MaskedSequence assuming every timestep is valid.""" if values.ndim < 2: raise ValueError(f'Expected {values.ndim=} to be at least 2.') - return MaskedSequence(values, mx.ones(values.shape[:2], dtype=mx.bool_)) + array_values = values if isinstance(values, mx.array) else mx.array(values) + return MaskedSequence( + array_values, mx.ones(array_values.shape[:2], dtype=mx.bool_) + ) @classmethod @override @@ -354,16 +385,19 @@ def mask_invalid( mask_value: complex | None = None, ) -> 'Sequence': """Returns a sequence with invalid timesteps replaced.""" + values = sequence.values + if not isinstance(values, mx.array): + values = mx.array(values) expanded_mask = sequence.expanded_mask() if mask_value is None: - masked_values = mx.zeros_like(sequence.values) + masked_values = mx.zeros_like(values) result_type: type[Sequence] = MaskedSequence else: masked_values = mx.full( - sequence.values.shape, mask_value, sequence.values.dtype # type: ignore[arg-type] + values.shape, mask_value, values.dtype # type: ignore[arg-type] ) result_type = Sequence - masked_values = mx.where(expanded_mask, sequence.values, masked_values) + masked_values = mx.where(expanded_mask, values, masked_values) return result_type(masked_values, sequence.mask) @@ -517,6 +551,20 @@ def output_ratio(self) -> fractions.Fraction: def supports_step(self) -> bool: return True + def get_output_shape_for_sequence( + self, + x: Sequence, + *, + constants: Constants | None = None, + ) -> Shape: + """Returns the output shape this layer produces for the provided Sequence.""" + return self.get_output_shape(x.channel_shape, constants=constants) + + @property + def name(self) -> str | None: + """Returns the name of the layer.""" + return self.config.name if hasattr(self, 'config') else None + @property @override def input_latency(self) -> int: @@ -770,7 +818,6 @@ class SequenceLayer( nn.Module, Steppable, spec.SequenceLayer[Sequence, Sequence, ChannelSpec], - metaclass=abc.ABCMeta, ): """Base Module for Sequence Layers.""" @@ -797,7 +844,6 @@ def copy(self, **kwargs) -> Self: class PreservesType( SequenceLayer, spec.PreservesType[Sequence, Sequence, ChannelSpec], - metaclass=abc.ABCMeta, ): """A mix-in for layers that do not change the input dtype.""" @@ -815,7 +861,6 @@ def get_output_dtype( class PreservesShape( SequenceLayer, spec.PreservesShape[Sequence, Sequence, ChannelSpec], - metaclass=abc.ABCMeta, ): """A mix-in for layers that do not change the input shape.""" @@ -906,10 +951,14 @@ class StatelessPointwise( PreservesShape, Stateless, spec.StatelessPointwise[Sequence, Sequence, ChannelSpec], - metaclass=abc.ABCMeta, ): """A SequenceLayer that has no state and operates pointwise on its input.""" + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + class StatelessPointwiseFunctor( StatelessPointwise, @@ -919,12 +968,12 @@ class StatelessPointwiseFunctor( @abc.abstractmethod @override - def fn(self, values: ValuesT, mask: MaskT) -> tuple[ValuesT, MaskT]: + def fn(self, values: mx.array, mask: mx.array) -> tuple[mx.array, mx.array]: """Transforms each scalar in values independently.""" @property @override - def mask_required(self): + def mask_required(self) -> bool: """Returns true if fn can change the sequence's masked state. If fn(0) -> 0, then mask_required() is False. diff --git a/sequence_layers/mlx/types_test.py b/sequence_layers/mlx/types_test.py index 1b7aadd..319bf1d 100644 --- a/sequence_layers/mlx/types_test.py +++ b/sequence_layers/mlx/types_test.py @@ -52,5 +52,9 @@ class StatelessPointwiseFunctorTest( pass +class HashableArrayTest(test_utils.SequenceLayerTest, spec.HashableArrayTest): + pass + + if __name__ == '__main__': absltest.main() From 7f0659ffdeb25ed1eb27ecb619e62de7b5dbedef Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Sat, 18 Apr 2026 22:29:46 -0700 Subject: [PATCH 5/5] chore: Fix abstract method warnings in MLX types and unused imports in specs --- sequence_layers/mlx/types.py | 4 ++++ sequence_layers/specs/simple.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sequence_layers/mlx/types.py b/sequence_layers/mlx/types.py index 7e75558..7aa3078 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -814,6 +814,7 @@ def get_output_spec( # --------------------------------------------------------------------------- +# pylint: disable=abstract-method class SequenceLayer( nn.Module, Steppable, @@ -841,6 +842,7 @@ def copy(self, **kwargs) -> Self: # --------------------------------------------------------------------------- +# pylint: disable=abstract-method class PreservesType( SequenceLayer, spec.PreservesType[Sequence, Sequence, ChannelSpec], @@ -858,6 +860,7 @@ def get_output_dtype( return input_dtype +# pylint: disable=abstract-method class PreservesShape( SequenceLayer, spec.PreservesShape[Sequence, Sequence, ChannelSpec], @@ -947,6 +950,7 @@ def step( return self.layer(x, training=training, constants=constants), state +# pylint: disable=abstract-method class StatelessPointwise( PreservesShape, Stateless, diff --git a/sequence_layers/specs/simple.py b/sequence_layers/specs/simple.py index 5d6faad..37572ea 100644 --- a/sequence_layers/specs/simple.py +++ b/sequence_layers/specs/simple.py @@ -7,11 +7,10 @@ import abc import dataclasses -from typing import (Any, Callable, Generic, Protocol, runtime_checkable, +from typing import (Any, Callable, Protocol, runtime_checkable, Sequence, TypeVar) from sequence_layers.specs import types as types_spec -from sequence_layers.specs.types import HashableArray # --------------------------------------------------------------------------- # Activation Functions (StatelessPointwiseFunctor)