Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ disable = [
"too-many-statements",
"too-many-branches",
"too-many-positional-arguments",
"too-many-public-methods",
"too-many-return-statements",
]


Expand All @@ -114,4 +116,7 @@ exclude = [
]

[tool.pyrefly]
errors = { missing-override-decorator = "error" }
# Pyrefly fails to properly support config: Config without defaults in Flax Modules
# (used in JAX), incorrectly treating them as dataclasses and complaining about
# field ordering. This effectively only impacts JAX files.
errors = { missing-override-decorator = "error", bad-class-definition = "ignore" }
53 changes: 53 additions & 0 deletions sequence_layers/jax/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import override

import jax.nn as jnn
import jax.numpy as jnp

from sequence_layers.specs import backend as spec
Expand All @@ -23,8 +24,60 @@ def array(self, a, dtype=None) -> types_spec.Array:
def zeros(self, shape, dtype=None) -> types_spec.Array:
return jnp.zeros(shape, dtype=dtype)

@override
def concatenate(self, arrays, axis=0) -> types_spec.Array:
return jnp.concatenate(arrays, axis=axis)

@override
def abs(self, x) -> types_spec.Array:
return jnp.abs(x)

@override
def exp(self, x) -> types_spec.Array:
return jnp.exp(x)

@override
def log(self, x) -> types_spec.Array:
return jnp.log(x)


xp: spec.xp = BackendWrapper()


class NNWrapper(spec.nn):
"""Wrapper around JAX activations to match backend protocol."""

@override
def relu(self, x: types_spec.Array) -> types_spec.Array:
return jnn.relu(x)

@override
def sigmoid(self, x: types_spec.Array) -> types_spec.Array:
return jnn.sigmoid(x)

@override
def tanh(self, x: types_spec.Array) -> types_spec.Array:
return jnn.tanh(x)

@override
def swish(self, x: types_spec.Array) -> types_spec.Array:
return jnn.swish(x)

@override
def gelu(self, x: types_spec.Array) -> types_spec.Array:
return jnn.gelu(x)

@override
def elu(self, x: types_spec.Array) -> types_spec.Array:
return jnn.elu(x)

@override
def softplus(self, x: types_spec.Array) -> types_spec.Array:
return jnn.softplus(x)

@override
def softmax(self, x: types_spec.Array, axis: int = -1) -> types_spec.Array:
return jnn.softmax(x, axis=axis)


nn: spec.nn = NNWrapper()
Loading