From 2425c3b7892b8efe15a094b63df15714e27d2b65 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 17 Apr 2026 01:35:40 -0700 Subject: [PATCH 1/5] feat: Add MLX dense.py entries from https://github.com/DBraun/sequence-layers/commit/80daa69bcb5a5580ff9fb73d13e416a1813b1462 Co-authored-by: David Braun <2096055+DBraun@users.noreply.github.com> --- sequence_layers/mlx/dense.py | 356 ++++++++++++++++++++++++++++++ sequence_layers/mlx/dense_test.py | 33 +++ 2 files changed, 389 insertions(+) create mode 100644 sequence_layers/mlx/dense.py create mode 100644 sequence_layers/mlx/dense_test.py diff --git a/sequence_layers/mlx/dense.py b/sequence_layers/mlx/dense.py new file mode 100644 index 0000000..7e8ad52 --- /dev/null +++ b/sequence_layers/mlx/dense.py @@ -0,0 +1,356 @@ +"""Dense sequence layer for MLX.""" + +import dataclasses +from typing import Callable, override + +from mlx import nn +import mlx.core as mx + +from sequence_layers.mlx import init_mapping +from sequence_layers.mlx import types as mlx_types +from sequence_layers.mlx.init_mapping import _to_mx_dtype +from sequence_layers.specs import dense as spec + + +class _DenseEager(mlx_types.Stateless): + """A basic dense layer backed by mlx.nn.Linear. + + Requires in_features at initialization. + """ + + def __init__( + self, + *, + in_features: int, + features: int, + use_bias: bool = True, + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + name: str | None = None, + ): + """Initialize _DenseEager.""" + super().__init__() + self.features = features + self.activation = activation + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + self.name = name + self._linear = nn.Linear(in_features, features, bias=use_bias) + + @property + def use_bias(self): + """Return whether bias is used.""" + return 'bias' in self._linear + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + + @override + def get_output_shape(self, input_shape, *, constants=None): + """Get output shape.""" + if not input_shape: + raise ValueError( + f'Dense requires at least rank 3 input. Got: {input_shape=}' + ) + return tuple(input_shape[:-1]) + (self.features,) + + @override + def get_output_dtype(self, input_dtype, *, constants=None): + if self.compute_dtype is not None: + return self.compute_dtype + return self._param_dtype + + @override + @mlx_types.check_layer + def layer(self, x, *, training: bool, constants=None): + compute_dtype = self.get_output_dtype(x.dtype) + + def dense_fn(v): + y = self._linear(v.astype(compute_dtype)) + if self.activation is not None: + y = self.activation(y) + return y + + if self.use_bias or self.activation is not None: + return x.apply_values(dense_fn) + return x.apply_values_masked(dense_fn) + + +class Dense(mlx_types.Stateless, spec.Dense): + """A basic dense layer with deferred initialization. + + Matches JAX interface where in_features is inferred. + """ + + @dataclasses.dataclass(frozen=True) + class Config(mlx_types.SequenceLayerConfig): + """Dense config.""" + + features: int + use_bias: bool = True + activation: Callable | None = None + compute_dtype: mlx_types.DType | None = None + param_dtype: mlx_types.DType = mx.float32 + name: str | None = None + + @override + def make(self) -> 'Dense': + return Dense.from_config(self) + + def __init__( + self, + *, + features: int, + use_bias: bool = True, + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + name: str | None = None, + ): + """Initialize Dense.""" + super().__init__() + self.features = features + self._use_bias = use_bias + self.activation = activation + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + self.name = name + self.inner = None + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + + def _ensure_initialized(self, in_features: int): + """Ensure inner _DenseEager is initialized.""" + if self.inner is not None: + return + self.inner = _DenseEager( + in_features=in_features, + features=self.features, + use_bias=self._use_bias, + activation=self.activation, + compute_dtype=self.compute_dtype, + param_dtype=self._param_dtype, + name=self.name, + ) + + @override + def get_output_shape(self, input_shape, *, constants=None): + """Get output shape.""" + if not input_shape: + raise ValueError( + f'Dense requires at least rank 3 input. Got: {input_shape=}' + ) + return tuple(input_shape[:-1]) + (self.features,) + + @override + def get_output_dtype(self, input_dtype, *, constants=None): + if self.compute_dtype is not None: + return self.compute_dtype + return self._param_dtype + + @override + @mlx_types.check_layer + def layer(self, x, *, training: bool, constants=None): + self._ensure_initialized(x.shape[-1]) + assert self.inner is not None + return self.inner.layer(x, training=training, constants=constants) + + @classmethod + def from_config(cls, config): + """Create Dense from config.""" + compute_dtype = getattr(config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = _to_mx_dtype(compute_dtype) + return cls( + features=config.features, + use_bias=config.use_bias, + activation=init_mapping.map_activation(config.activation), + compute_dtype=compute_dtype, + param_dtype=_to_mx_dtype(config.param_dtype), + name=config.name, + ) + + +# pylint: disable=too-many-instance-attributes +class EinsumDense(mlx_types.Stateless, spec.EinsumDense): + """Dense layer using Einstein summation notation.""" + + @dataclasses.dataclass(frozen=True) + class Config(mlx_types.SequenceLayerConfig): + """MLX-native configuration for EinsumDense.""" + + equation: str = '' + output_shape: tuple[int | None, ...] = () + bias_axes: str = '' + activation: Callable | None = None + compute_dtype: mlx_types.DType | None = None + param_dtype: mlx_types.DType = mx.float32 + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'output_shape', tuple(self.output_shape)) + + @override + def make(self) -> 'EinsumDense': + return EinsumDense.from_config(self) + + def __init__( + self, + *, + equation, + output_shape, + bias_axes='', + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + name: str | None = None, + ): + """Initialize EinsumDense.""" + super().__init__() + self._equation = equation + self._output_shape_spec = tuple(output_shape) + self._bias_axes = bias_axes + self._activation = activation + self._compute_dtype = compute_dtype + self._param_dtype = param_dtype + self.name = name + self.kernel = None + self.bias = None + self._initialized = False + self._resolved_output_shape = None + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + + def _ensure_initialized(self, input_shape): + """Ensure parameters are initialized.""" + if self._initialized: + return + output_shape, kernel_shape, bias_shape = _compute_shapes( + self._equation, input_shape, self._output_shape_spec, self._bias_axes + ) + self._resolved_output_shape = output_shape + self.kernel = mx.zeros(kernel_shape, dtype=self._param_dtype) + if bias_shape is not None: + self.bias = mx.zeros(bias_shape, dtype=self._param_dtype) + self._initialized = True + + @override + def get_output_shape(self, input_shape, *, constants=None): + """Get output shape.""" + output_shape, _, _ = _compute_shapes( + self._equation, input_shape, self._output_shape_spec, self._bias_axes + ) + return output_shape + + @override + def get_output_dtype(self, input_dtype, *, constants=None): + if self._compute_dtype is not None: + return self._compute_dtype + return self._param_dtype + + @override + @mlx_types.check_layer + def layer(self, x, *, training: bool, constants=None): + self._ensure_initialized(x.channel_shape) + compute_dtype = self.get_output_dtype(x.dtype) + + def einsum_fn(v): + y = mx.einsum(self._equation, v.astype(compute_dtype), self.kernel) + if self.bias is not None: + y = y + self.bias + if self._activation is not None: + y = self._activation(y) + return y + + if self.bias is not None or self._activation is not None: + return x.apply_values(einsum_fn) + return x.apply_values_masked(einsum_fn) + + @classmethod + def from_config(cls, config): + """Create EinsumDense from config.""" + compute_dtype = getattr(config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = _to_mx_dtype(compute_dtype) + return cls( + equation=config.equation, + output_shape=config.output_shape, + bias_axes=config.bias_axes, + activation=init_mapping.map_activation(config.activation), + compute_dtype=compute_dtype, + param_dtype=_to_mx_dtype(config.param_dtype), + name=config.name, + ) + + +def _parse_equation(equation): + """Parse einsum equation of form '...ab,bc->...ac'.""" + if '->' not in equation: + raise ValueError(f'equation is not valid for EinsumDense: {equation}') + left, output_spec = equation.split('->') + input_spec, kernel_spec = left.split(',') + if not input_spec.startswith('...') or not output_spec.startswith('...'): + raise ValueError('Equation must be of the form "...X,Y->...Z".') + if 3 + len(set(input_spec[3:])) != len(input_spec): + raise ValueError( + f'Equation {input_spec=} must not contain duplicate variables.' + ) + if 3 + len(set(output_spec[3:])) != len(output_spec): + raise ValueError( + f'Equation {output_spec=} must not contain duplicate variables.' + ) + return input_spec, kernel_spec, output_spec + + +def _compute_shapes(equation, input_shape, output_shape_spec, bias_axes): + """Compute kernel_shape and bias_shape from equation and shapes.""" + input_spec, kernel_spec, output_spec = _parse_equation(equation) + in_spec = input_spec[3:] + out_spec = output_spec[3:] + + if len(in_spec) != len(input_shape): + raise ValueError(f'Equation {in_spec=} does not match {input_shape=} rank.') + + input_dims = {d: input_shape[i] for i, d in enumerate(in_spec)} + output_shape = list(output_shape_spec) + if len(out_spec) != len(output_shape): + raise ValueError(f'Equation {out_spec=} does not match {output_shape=}.') + + for i, d in enumerate(out_spec): + if output_shape[i] is None: + output_shape[i] = input_dims[d] + elif d in input_dims and output_shape[i] != input_dims[d]: + raise ValueError( + f'Inconsistent dimension {d=}. {output_shape=} vs {input_shape=}' + ) + + output_dim_map = {d: output_shape[i] for i, d in enumerate(out_spec)} + + kernel_shape = [] + for d in kernel_spec: + if d in input_dims: + kernel_shape.append(input_dims[d]) + elif d in output_dim_map: + kernel_shape.append(output_dim_map[d]) + else: + raise ValueError(f"Weight dimension '{d}' not in input or output spec.") + + if bias_axes: + first_bias_loc = min(out_spec.find(c) for c in bias_axes) + bias_out_spec = out_spec[first_bias_loc:] + bias_shape = [ + output_dim_map[c] if c in bias_axes else 1 for c in bias_out_spec + ] + else: + bias_shape = None + + return tuple(output_shape), tuple(kernel_shape), bias_shape diff --git a/sequence_layers/mlx/dense_test.py b/sequence_layers/mlx/dense_test.py new file mode 100644 index 0000000..0cf9055 --- /dev/null +++ b/sequence_layers/mlx/dense_test.py @@ -0,0 +1,33 @@ +"""Tests for Dense MLX sequence layers.""" + +from absl.testing import absltest +from mlx import nn + +from sequence_layers.mlx import dense +from sequence_layers.mlx import test_utils +from sequence_layers.specs import dense_behaviors as spec + + +class DenseTest(test_utils.SequenceLayerTest, spec.DenseTest): + """Test behavior of Dense layer.""" + + def test_eager_layer(self): + """Test DenseEager which is not in the spec.""" + layer = dense._DenseEager(in_features=4, features=8) + x = self.random_sequence(2, 3, 4) + self.verify_contract(layer, x) + + def test_activation(self): + """Test activation in Dense.""" + # Test with deferred Dense + layer = dense.Dense(features=8, activation=nn.relu) + x = self.random_sequence(2, 3, 4) + self.verify_contract(layer, x) + + +class EinsumDenseTest(test_utils.SequenceLayerTest, spec.EinsumDenseTest): + """Test behavior of EinsumDense layer.""" + + +if __name__ == '__main__': + absltest.main() From 3386f08d36cb1f0c4f3db776744a1085e939c0df Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 17 Apr 2026 01:37:22 -0700 Subject: [PATCH 2/5] refactor(specs): Add shared specs and behaviors for Dense --- sequence_layers/specs/__init__.py | 9 ++ sequence_layers/specs/dense.py | 49 +++++++++++ sequence_layers/specs/dense_behaviors.py | 106 +++++++++++++++++++++++ 3 files changed, 164 insertions(+) create mode 100644 sequence_layers/specs/dense.py create mode 100644 sequence_layers/specs/dense_behaviors.py diff --git a/sequence_layers/specs/__init__.py b/sequence_layers/specs/__init__.py index 10df687..c5ee4cd 100644 --- a/sequence_layers/specs/__init__.py +++ b/sequence_layers/specs/__init__.py @@ -5,6 +5,7 @@ from typing import Protocol, runtime_checkable, TYPE_CHECKING from . import backend as _backend +from . import dense as _dense from . import simple as _simple from . import types as _types @@ -116,3 +117,11 @@ def Embedding(self) -> type[_simple.Embedding]: @property def Softmax(self) -> type[_simple.Softmax]: ... + + @property + def Dense(self) -> type[_dense.Dense]: + ... + + @property + def EinsumDense(self) -> type[_dense.EinsumDense]: + ... diff --git a/sequence_layers/specs/dense.py b/sequence_layers/specs/dense.py new file mode 100644 index 0000000..0ebb994 --- /dev/null +++ b/sequence_layers/specs/dense.py @@ -0,0 +1,49 @@ +"""Specifications for dense layers. + +See the corresponding _behaviors module for behaviors. +""" + +import abc +import dataclasses +from typing import Any, Callable, Sequence + +from sequence_layers.specs import types as types_spec + + +class Dense(types_spec.Stateless, metaclass=abc.ABCMeta): + """Specification for Dense layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Dense layer.""" + + features: int + use_bias: bool = True + activation: Callable | None = None + compute_dtype: types_spec.DType | None = None + param_dtype: types_spec.DType | None = None + name: str | None = None + + def make(self) -> Any: + """Dummy make to satisfy Pyrefly.""" + pass + + +class EinsumDense(types_spec.Stateless, metaclass=abc.ABCMeta): + """Specification for EinsumDense layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for EinsumDense layer.""" + + equation: str + output_shape: Sequence[int | None] + bias_axes: str = '' + activation: Callable | None = None + compute_dtype: types_spec.DType | None = None + param_dtype: types_spec.DType | None = None + name: str | None = None + + def make(self) -> Any: + """Dummy make to satisfy Pyrefly.""" + pass diff --git a/sequence_layers/specs/dense_behaviors.py b/sequence_layers/specs/dense_behaviors.py new file mode 100644 index 0000000..e426b7d --- /dev/null +++ b/sequence_layers/specs/dense_behaviors.py @@ -0,0 +1,106 @@ +"""Behavior tests for dense layers. + +Backend-specific test files should inherit from these tests. +""" + +# pylint: disable=abstract-method + +from absl.testing import parameterized + +from sequence_layers.specs import test_utils + + +class DenseTest(test_utils.SequenceLayerTest): + """Test behavior of Dense layer.""" + + def test_rank2_unsupported(self): + l = self.sl.Dense.Config(features=3, name='dense').make() + x = self.random_sequence(2, 13) + with self.assertRaises(ValueError): + self.init_layer(l, x) + + @parameterized.parameters(((5,),), ((5, 7),)) + def test_dense(self, channels_shape): + l = self.sl.Dense.Config(features=3, name='dense').make() + x = self.random_sequence(2, 13, *channels_shape, random_mask=True) + l = self.init_layer(l, x) + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual(l.name, 'dense') + self.assertEqual( + l.get_output_shape_for_sequence(x), channels_shape[:-1] + (3,) + ) + self.verify_contract(l, x, training=False) + + @parameterized.parameters(True, False) + def test_use_bias(self, use_bias): + l = self.sl.Dense.Config(features=3, use_bias=use_bias).make() + x = self.random_sequence(2, 3, 5) + l = self.init_layer(l, x) + self.verify_contract(l, x, training=False) + + +class EinsumDenseTest(test_utils.SequenceLayerTest): + """Test behavior of EinsumDense layer.""" + + @parameterized.parameters( + ( + (2, 3, 5), + '...a,ab->...b', + (7,), + '', + (7,), + ), + ( + (2, 3, 5, 7), + '...ab,ac->...cb', + (11, 7), + 'c', + (11, 7), + ), + ( + (2, 3, 5, 7), + '...ab,b->...a', + (None,), + '', + (5,), + ), + ) + def test_einsum_dense( + self, + shape, + equation, + output_shape, + bias_axes, + expected_output_shape, + ): + x = self.random_sequence(*shape) + l = self.sl.EinsumDense.Config( + equation=equation, + output_shape=output_shape, + bias_axes=bias_axes, + name='einsum_dense', + ).make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual(l.name, 'einsum_dense') + self.assertEqual(l.get_output_shape_for_sequence(x), expected_output_shape) + self.verify_contract(l, x, training=False) + + def test_einsum_dense_nonbroadcasting_equation(self): + with self.assertRaises(ValueError): + x = self.random_sequence(2, 3, 4, 5, 6) + l = self.sl.EinsumDense.Config( + equation='btabc,bc->btad', output_shape=[None, 2] + ).make() + l = self.init_layer(l, x) + + def test_einsum_dense_inconsistent_input_shape(self): + x = self.random_sequence(2, 3, 5) + l = self.sl.EinsumDense.Config( + equation='...abc,bc->...ad', output_shape=[None, 2] + ).make() + with self.assertRaises(ValueError): + self.init_layer(l, x) From 252242deb51064dd74cbd153d88dbbe62b86bbf1 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 17 Apr 2026 11:19:18 -0700 Subject: [PATCH 3/5] refactor(jax): Standardize Dense and tests to use shared specs --- pyproject.toml | 1 + sequence_layers/jax/dense.py | 17 ++++--- sequence_layers/jax/dense_test.py | 81 ++++++++----------------------- 3 files changed, 33 insertions(+), 66 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d9a6672..32b834c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ disable = [ "too-many-positional-arguments", "too-many-public-methods", "too-many-return-statements", + "too-many-instance-attributes", ] diff --git a/sequence_layers/jax/dense.py b/sequence_layers/jax/dense.py index 5fbadb0..7abae4c 100644 --- a/sequence_layers/jax/dense.py +++ b/sequence_layers/jax/dense.py @@ -15,15 +15,16 @@ import dataclasses import typing -from typing import Callable +from typing import Callable, override import flax.linen as nn import jax import jax.numpy as jnp + from sequence_layers.jax import meta from sequence_layers.jax import types from sequence_layers.jax import utils - +from sequence_layers.specs import dense as spec __all__ = ( # go/keep-sorted start @@ -34,11 +35,11 @@ ) -class Dense(types.Stateless, utils.EinsumCommon): +class Dense(types.Stateless, utils.EinsumCommon, spec.Dense): """A basic dense layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(types.SequenceLayerConfig, spec.Dense.Config): """Dense config.""" # The number of output features for the dense layer. @@ -73,6 +74,8 @@ class Config(types.SequenceLayerConfig): def make(self) -> 'Dense': return Dense(self, name=self.name) + + config: Config @nn.nowrap @@ -269,7 +272,7 @@ def layer( ) -class EinsumDense(types.Stateless, utils.EinsumCommon): +class EinsumDense(types.Stateless, utils.EinsumCommon, spec.EinsumDense): """A dense layer that transforms the channel shape with an einsum equation. Equation input and output specs must have leading ellipses to broadcast over @@ -291,7 +294,7 @@ class EinsumDense(types.Stateless, utils.EinsumCommon): """ @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(types.SequenceLayerConfig, spec.EinsumDense.Config): """EinsumDense config.""" # An equation describing the einsum to perform. This equation must be a @@ -338,6 +341,8 @@ def __post_init__(self): def make(self) -> 'EinsumDense': return EinsumDense(self, name=self.name) + + config: Config @nn.nowrap diff --git a/sequence_layers/jax/dense_test.py b/sequence_layers/jax/dense_test.py index 0edd20b..3820c1f 100644 --- a/sequence_layers/jax/dense_test.py +++ b/sequence_layers/jax/dense_test.py @@ -19,21 +19,14 @@ import flax.linen as nn import jax import jax.numpy as jnp + from sequence_layers.jax import dense from sequence_layers.jax import test_utils from sequence_layers.jax import types +from sequence_layers.specs import dense_behaviors as spec -class DenseTest(test_utils.SequenceLayerTest): - - def test_rank2_unsupported(self): - key = jax.random.PRNGKey(1234) - l = dense.Dense.Config( - 3, bias_init=nn.initializers.normal(), name='dense' - ).make() - x = test_utils.random_sequence(2, 13) - with self.assertRaises(ValueError): - self.init_and_bind_layer(key, l, x) +class DenseTest(test_utils.SequenceLayerTest, spec.DenseTest): @parameterized.parameters(((5,),), ((5, 7),)) def test_dense(self, channels_shape): @@ -49,7 +42,7 @@ def test_dense(self, channels_shape): self.assertEqual( l.get_output_shape_for_sequence(x), channels_shape[:-1] + (3,) ) - self.verify_contract(l, x, training=False, grad_rtol=1e-5, grad_atol=1e-5) + self.verify_contract(l, x, training=False, rtol=1e-5, atol=1e-5, grad_rtol=1e-5, grad_atol=1e-5) chex.assert_trees_all_equal_shapes_and_dtypes( flax.core.meta.unbox(l.variables), @@ -61,17 +54,6 @@ def test_dense(self, channels_shape): }, ) - @parameterized.parameters(True, False) - def test_use_bias(self, use_bias): - """Check that use_bias controls whether a bias is created.""" - key = jax.random.PRNGKey(1234) - l = dense.Dense.Config(3, use_bias=use_bias).make() - x = test_utils.random_sequence(2, 3, 5) - l = self.init_and_bind_layer(key, l, x) - self.assertCountEqual( - l.variables['params'], ['kernel', 'bias'] if use_bias else ['kernel'] - ) - def test_use_einsum_factory(self): """Check that einsum_factory produces is used for dense einsum.""" @@ -254,7 +236,7 @@ def test_dtypes(self, param_dtype, input_dtype, compute_dtype, use_bias): ) -class EinsumDenseTest(test_utils.SequenceLayerTest): +class EinsumDenseTest(test_utils.SequenceLayerTest, spec.EinsumDenseTest): @parameterized.parameters( ( @@ -461,22 +443,22 @@ def custom_einsum(equation, *args, **kwargs): @parameterized.product( test_utils.standard_dtype_configs(), ( - dict( - shape=(2, 3, 5, 7, 11), - equation='...abc,bd->...bd', - output_shape=(None, 13), - expected_kernel_shape=(7, 13), - bias_axes='', - expected_bias_shape=None, - ), - dict( - shape=(2, 3, 5), - equation='...a,abcd->...bcd', - output_shape=(7, 11, 13), - expected_kernel_shape=(5, 7, 11, 13), - bias_axes='cd', - expected_bias_shape=(11, 13), - ), + { + 'shape': (2, 3, 5, 7, 11), + 'equation': '...abc,bd->...bd', + 'output_shape': (None, 13), + 'expected_kernel_shape': (7, 13), + 'bias_axes': '', + 'expected_bias_shape': None, + }, + { + 'shape': (2, 3, 5), + 'equation': '...a,abcd->...bcd', + 'output_shape': (7, 11, 13), + 'expected_kernel_shape': (5, 7, 11, 13), + 'bias_axes': 'cd', + 'expected_bias_shape': (11, 13), + }, ), ) def test_dtypes( @@ -536,27 +518,6 @@ def test_dtypes( ).mask_invalid() self.assertSequencesClose(y, y_expected) - def test_einsum_dense_nonbroadcasting_equation(self): - with self.assertRaises(ValueError): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(2, 3, 4, 5, 6) - l = dense.EinsumDense.Config( - 'btabc,bc->btad', output_shape=[None, 2] - ).make() - self.init_and_bind_layer(key, l, x) - - def test_einsum_dense_inconsistent_input_shape(self): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(2, 3, 5) - l = dense.EinsumDense.Config( - '...abc,bc->...ad', output_shape=[None, 2] - ).make() - with self.assertRaises(ValueError): - self.init_and_bind_layer(key, l, x) - # Show it works with the right input shape. - x = test_utils.random_sequence(2, 3, 5, 7, 11) - self.assertEqual(l.get_output_shape_for_sequence(x), (5, 2)) - if __name__ == '__main__': test_utils.main() From a13bb47b41ab49b009f4dc50a6d4de3530c9fc6f Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 17 Apr 2026 11:27:24 -0700 Subject: [PATCH 4/5] refactor(mlx): Standardize Dense and tests with spec behaviors - Dense and EinsumDense use __init__(config) pattern with deferred init - Removed _DenseEager wrapper, using dynamic submodule creation directly - Added .copy() with @override to backend Config classes - Exported Dense/EinsumDense from mlx/__init__.py - Updated behavior tests for deferred-init compatibility --- sequence_layers/mlx/__init__.py | 6 + sequence_layers/mlx/dense.py | 244 +++++++---------------- sequence_layers/mlx/dense_test.py | 10 +- sequence_layers/specs/dense_behaviors.py | 7 +- 4 files changed, 82 insertions(+), 185 deletions(-) diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 6a17923..4eab10f 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -14,6 +14,7 @@ """Sequence layers in MLX.""" from . import backend +from . import dense from . import simple from . import types # CRITICAL: Do NOT use wildcard imports (e.g., `from .simple import *`) here. @@ -26,6 +27,8 @@ # Explicit imports (e.g., `from .simple import Relu`) DO NOT trigger this issue. # If you need to expose specific layers at the package level, import them # explicitly instead of using a star import. +from .dense import Dense +from .dense import EinsumDense from .simple import Abs from .simple import Add from .simple import Cast @@ -65,6 +68,7 @@ from .types import SequenceLayerConfig __all__ = [ + 'dense', 'backend', 'types', 'simple', @@ -72,6 +76,8 @@ 'MaskedSequence', 'SequenceLayer', 'SequenceLayerConfig', + 'Dense', + 'EinsumDense', 'Identity', 'Relu', 'Gelu', diff --git a/sequence_layers/mlx/dense.py b/sequence_layers/mlx/dense.py index 7e8ad52..2472d76 100644 --- a/sequence_layers/mlx/dense.py +++ b/sequence_layers/mlx/dense.py @@ -6,119 +6,39 @@ from mlx import nn import mlx.core as mx -from sequence_layers.mlx import init_mapping -from sequence_layers.mlx import types as mlx_types -from sequence_layers.mlx.init_mapping import _to_mx_dtype +from sequence_layers.mlx import types +from sequence_layers.mlx.simple import _to_mx_dtype from sequence_layers.specs import dense as spec -class _DenseEager(mlx_types.Stateless): - """A basic dense layer backed by mlx.nn.Linear. - - Requires in_features at initialization. - """ - - def __init__( - self, - *, - in_features: int, - features: int, - use_bias: bool = True, - activation=None, - compute_dtype=None, - param_dtype=mx.float32, - name: str | None = None, - ): - """Initialize _DenseEager.""" - super().__init__() - self.features = features - self.activation = activation - self.compute_dtype = compute_dtype - self._param_dtype = param_dtype - self.name = name - self._linear = nn.Linear(in_features, features, bias=use_bias) - - @property - def use_bias(self): - """Return whether bias is used.""" - return 'bias' in self._linear - - @property - @override - def receptive_field(self) -> tuple[int, int]: - return (0, 0) - - @override - def get_output_shape(self, input_shape, *, constants=None): - """Get output shape.""" - if not input_shape: - raise ValueError( - f'Dense requires at least rank 3 input. Got: {input_shape=}' - ) - return tuple(input_shape[:-1]) + (self.features,) - - @override - def get_output_dtype(self, input_dtype, *, constants=None): - if self.compute_dtype is not None: - return self.compute_dtype - return self._param_dtype - - @override - @mlx_types.check_layer - def layer(self, x, *, training: bool, constants=None): - compute_dtype = self.get_output_dtype(x.dtype) - - def dense_fn(v): - y = self._linear(v.astype(compute_dtype)) - if self.activation is not None: - y = self.activation(y) - return y - - if self.use_bias or self.activation is not None: - return x.apply_values(dense_fn) - return x.apply_values_masked(dense_fn) - - -class Dense(mlx_types.Stateless, spec.Dense): +class Dense(types.Stateless, spec.Dense): """A basic dense layer with deferred initialization. - Matches JAX interface where in_features is inferred. + Matches JAX interface where in_features is inferred on first call. """ @dataclasses.dataclass(frozen=True) - class Config(mlx_types.SequenceLayerConfig): + class Config(types.SequenceLayerConfig, spec.Dense.Config): """Dense config.""" features: int use_bias: bool = True activation: Callable | None = None - compute_dtype: mlx_types.DType | None = None - param_dtype: mlx_types.DType = mx.float32 + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 name: str | None = None @override def make(self) -> 'Dense': - return Dense.from_config(self) - - def __init__( - self, - *, - features: int, - use_bias: bool = True, - activation=None, - compute_dtype=None, - param_dtype=mx.float32, - name: str | None = None, - ): + return Dense(self) + + def __init__(self, config: Config): """Initialize Dense.""" super().__init__() - self.features = features - self._use_bias = use_bias - self.activation = activation - self.compute_dtype = compute_dtype - self._param_dtype = param_dtype - self.name = name - self.inner = None + self.config = config + self._compute_dtype = _to_mx_dtype(config.compute_dtype) + self._param_dtype = _to_mx_dtype(config.param_dtype) + self._linear = None @property @override @@ -126,17 +46,11 @@ def receptive_field(self) -> tuple[int, int]: return (0, 0) def _ensure_initialized(self, in_features: int): - """Ensure inner _DenseEager is initialized.""" - if self.inner is not None: + """Ensure nn.Linear is initialized on first call.""" + if self._linear is not None: return - self.inner = _DenseEager( - in_features=in_features, - features=self.features, - use_bias=self._use_bias, - activation=self.activation, - compute_dtype=self.compute_dtype, - param_dtype=self._param_dtype, - name=self.name, + self._linear = nn.Linear( + in_features, self.config.features, bias=self.config.use_bias ) @override @@ -146,51 +60,51 @@ def get_output_shape(self, input_shape, *, constants=None): raise ValueError( f'Dense requires at least rank 3 input. Got: {input_shape=}' ) - return tuple(input_shape[:-1]) + (self.features,) + return tuple(input_shape[:-1]) + (self.config.features,) @override def get_output_dtype(self, input_dtype, *, constants=None): - if self.compute_dtype is not None: - return self.compute_dtype + if self._compute_dtype is not None: + return self._compute_dtype + assert self._param_dtype is not None return self._param_dtype @override - @mlx_types.check_layer - def layer(self, x, *, training: bool, constants=None): + @types.check_layer + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + if x.ndim < 3: + raise ValueError(f'Dense requires at least rank 3 input. Got: {x.shape=}') self._ensure_initialized(x.shape[-1]) - assert self.inner is not None - return self.inner.layer(x, training=training, constants=constants) - - @classmethod - def from_config(cls, config): - """Create Dense from config.""" - compute_dtype = getattr(config, 'compute_dtype', None) - if compute_dtype is not None: - compute_dtype = _to_mx_dtype(compute_dtype) - return cls( - features=config.features, - use_bias=config.use_bias, - activation=init_mapping.map_activation(config.activation), - compute_dtype=compute_dtype, - param_dtype=_to_mx_dtype(config.param_dtype), - name=config.name, - ) + assert self._linear is not None + activation = self.config.activation + compute_dtype = self.get_output_dtype(x.dtype) + def dense_fn(v): + y = self._linear(v.astype(compute_dtype)) + if activation is not None: + y = activation(y) + return y + + if self.config.use_bias or activation is not None: + return x.apply_values(dense_fn) + return x.apply_values_masked(dense_fn) -# pylint: disable=too-many-instance-attributes -class EinsumDense(mlx_types.Stateless, spec.EinsumDense): + +class EinsumDense(types.Stateless, spec.EinsumDense): """Dense layer using Einstein summation notation.""" @dataclasses.dataclass(frozen=True) - class Config(mlx_types.SequenceLayerConfig): + class Config(types.SequenceLayerConfig, spec.EinsumDense.Config): """MLX-native configuration for EinsumDense.""" equation: str = '' output_shape: tuple[int | None, ...] = () bias_axes: str = '' activation: Callable | None = None - compute_dtype: mlx_types.DType | None = None - param_dtype: mlx_types.DType = mx.float32 + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 name: str | None = None def __post_init__(self): @@ -198,28 +112,14 @@ def __post_init__(self): @override def make(self) -> 'EinsumDense': - return EinsumDense.from_config(self) - - def __init__( - self, - *, - equation, - output_shape, - bias_axes='', - activation=None, - compute_dtype=None, - param_dtype=mx.float32, - name: str | None = None, - ): + return EinsumDense(self) + + def __init__(self, config: Config): """Initialize EinsumDense.""" super().__init__() - self._equation = equation - self._output_shape_spec = tuple(output_shape) - self._bias_axes = bias_axes - self._activation = activation - self._compute_dtype = compute_dtype - self._param_dtype = param_dtype - self.name = name + self.config = config + self._compute_dtype = _to_mx_dtype(config.compute_dtype) + self._param_dtype = _to_mx_dtype(config.param_dtype) self.kernel = None self.bias = None self._initialized = False @@ -235,7 +135,10 @@ def _ensure_initialized(self, input_shape): if self._initialized: return output_shape, kernel_shape, bias_shape = _compute_shapes( - self._equation, input_shape, self._output_shape_spec, self._bias_axes + self.config.equation, + input_shape, + self.config.output_shape, + self.config.bias_axes, ) self._resolved_output_shape = output_shape self.kernel = mx.zeros(kernel_shape, dtype=self._param_dtype) @@ -247,7 +150,10 @@ def _ensure_initialized(self, input_shape): def get_output_shape(self, input_shape, *, constants=None): """Get output shape.""" output_shape, _, _ = _compute_shapes( - self._equation, input_shape, self._output_shape_spec, self._bias_axes + self.config.equation, + input_shape, + self.config.output_shape, + self.config.bias_axes, ) return output_shape @@ -255,42 +161,30 @@ def get_output_shape(self, input_shape, *, constants=None): def get_output_dtype(self, input_dtype, *, constants=None): if self._compute_dtype is not None: return self._compute_dtype + assert self._param_dtype is not None return self._param_dtype @override - @mlx_types.check_layer - def layer(self, x, *, training: bool, constants=None): + @types.check_layer + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): self._ensure_initialized(x.channel_shape) compute_dtype = self.get_output_dtype(x.dtype) + activation = self.config.activation def einsum_fn(v): - y = mx.einsum(self._equation, v.astype(compute_dtype), self.kernel) + y = mx.einsum(self.config.equation, v.astype(compute_dtype), self.kernel) if self.bias is not None: y = y + self.bias - if self._activation is not None: - y = self._activation(y) + if activation is not None: + y = activation(y) return y - if self.bias is not None or self._activation is not None: + if self.bias is not None or activation is not None: return x.apply_values(einsum_fn) return x.apply_values_masked(einsum_fn) - @classmethod - def from_config(cls, config): - """Create EinsumDense from config.""" - compute_dtype = getattr(config, 'compute_dtype', None) - if compute_dtype is not None: - compute_dtype = _to_mx_dtype(compute_dtype) - return cls( - equation=config.equation, - output_shape=config.output_shape, - bias_axes=config.bias_axes, - activation=init_mapping.map_activation(config.activation), - compute_dtype=compute_dtype, - param_dtype=_to_mx_dtype(config.param_dtype), - name=config.name, - ) - def _parse_equation(equation): """Parse einsum equation of form '...ab,bc->...ac'.""" diff --git a/sequence_layers/mlx/dense_test.py b/sequence_layers/mlx/dense_test.py index 0cf9055..a323f24 100644 --- a/sequence_layers/mlx/dense_test.py +++ b/sequence_layers/mlx/dense_test.py @@ -11,17 +11,11 @@ class DenseTest(test_utils.SequenceLayerTest, spec.DenseTest): """Test behavior of Dense layer.""" - def test_eager_layer(self): - """Test DenseEager which is not in the spec.""" - layer = dense._DenseEager(in_features=4, features=8) - x = self.random_sequence(2, 3, 4) - self.verify_contract(layer, x) - def test_activation(self): """Test activation in Dense.""" - # Test with deferred Dense - layer = dense.Dense(features=8, activation=nn.relu) + layer = dense.Dense.Config(features=8, activation=nn.relu).make() x = self.random_sequence(2, 3, 4) + layer = self.init_layer(layer, x) self.verify_contract(layer, x) diff --git a/sequence_layers/specs/dense_behaviors.py b/sequence_layers/specs/dense_behaviors.py index e426b7d..87a2773 100644 --- a/sequence_layers/specs/dense_behaviors.py +++ b/sequence_layers/specs/dense_behaviors.py @@ -17,7 +17,8 @@ def test_rank2_unsupported(self): l = self.sl.Dense.Config(features=3, name='dense').make() x = self.random_sequence(2, 13) with self.assertRaises(ValueError): - self.init_layer(l, x) + l = self.init_layer(l, x) + l.layer(x, training=False) @parameterized.parameters(((5,),), ((5, 7),)) def test_dense(self, channels_shape): @@ -96,6 +97,7 @@ def test_einsum_dense_nonbroadcasting_equation(self): equation='btabc,bc->btad', output_shape=[None, 2] ).make() l = self.init_layer(l, x) + l.layer(x, training=False) def test_einsum_dense_inconsistent_input_shape(self): x = self.random_sequence(2, 3, 5) @@ -103,4 +105,5 @@ def test_einsum_dense_inconsistent_input_shape(self): equation='...abc,bc->...ad', output_shape=[None, 2] ).make() with self.assertRaises(ValueError): - self.init_layer(l, x) + l = self.init_layer(l, x) + l.layer(x, training=False) From ab968efc3675af997830e23c9b048a85992bef2b Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Sat, 18 Apr 2026 22:30:02 -0700 Subject: [PATCH 5/5] chore: Remove unnecessary pass statements in specs/dense.py --- sequence_layers/specs/dense.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sequence_layers/specs/dense.py b/sequence_layers/specs/dense.py index 0ebb994..ec5e10d 100644 --- a/sequence_layers/specs/dense.py +++ b/sequence_layers/specs/dense.py @@ -26,7 +26,6 @@ class Config(types_spec.SequenceLayerConfig): def make(self) -> Any: """Dummy make to satisfy Pyrefly.""" - pass class EinsumDense(types_spec.Stateless, metaclass=abc.ABCMeta): @@ -46,4 +45,3 @@ class Config(types_spec.SequenceLayerConfig): def make(self) -> Any: """Dummy make to satisfy Pyrefly.""" - pass