Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions folx/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import logging
import re
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, ParamSpec, Sequence, TypeVar
Expand Down Expand Up @@ -44,6 +45,36 @@
P = ParamSpec('P')


def _version_key(version: str) -> tuple[int, ...]:
return tuple(
int(match.group(0))
for part in version.split('.')
if (match := re.match(r'\d+', part)) is not None
)


_USES_DICT_BIND_PARAMS = _version_key(jax.__version__) >= (0, 9, 2)


def _split_bind_params(primitive, params):
"""Normalize JAX primitive bind params across old and new JAX APIs."""

bind_params = primitive.get_bind_params(params)
if _USES_DICT_BIND_PARAMS:
if not isinstance(bind_params, dict):
raise TypeError(
'Expected primitive.get_bind_params() to return a dict for '
f'JAX >= 0.9.2, got {bind_params!r}.'
)
return (), bind_params
if not isinstance(bind_params, tuple) or len(bind_params) != 2:
raise TypeError(
'Expected primitive.get_bind_params() to return a 2-tuple for '
f'JAX < 0.9.2, got {bind_params!r}.'
)
return bind_params


class JaxExprEnvironment:
# A simple environment that keeps track of the variables
# and frees them once they are no longer needed.
Expand Down Expand Up @@ -169,15 +200,15 @@ def eval_pjit(eqn: JaxprEqn, invals):
)

def eval_custom_jvp(eqn: JaxprEqn, invals):
subfuns, args = eqn.primitive.get_bind_params(eqn.params)
subfuns, args = _split_bind_params(eqn.primitive, eqn.params)
fn = functools.partial(eqn.primitive.bind, *subfuns, **args)
with LoggingPrefix(f'({summarize(eqn.source_info)})'):
return wrap_forward_laplacian(fn)(
invals, {}, sparsity_threshold=sparsity_threshold
)

def eval_laplacian(eqn: JaxprEqn, invals):
subfuns, params = eqn.primitive.get_bind_params(eqn.params)
subfuns, params = _split_bind_params(eqn.primitive, eqn.params)
with LoggingPrefix(f'({summarize(eqn.source_info)})'):
fn = get_laplacian(eqn.primitive, True)
return fn(
Expand All @@ -189,7 +220,7 @@ def eval_laplacian(eqn: JaxprEqn, invals):
# Eval expression
try:
if all(not isinstance(x, FwdLaplArray) for x in invals):
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
subfuns, bind_params = _split_bind_params(eqn.primitive, eqn.params)
# If non of the inputs were dependent on an FwdLaplArray,
# we can just use the regular primitive. This will avoid
# omnistaging. While this could cost us some memory and speed,
Expand Down
48 changes: 48 additions & 0 deletions test/test_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest

import folx.interpreter as interpreter_mod


class _DummyPrimitive:
def __init__(self, result):
self.result = result

def get_bind_params(self, params):
assert params == {'x': 1}
return self.result


def test_split_bind_params_accepts_legacy_tuple(monkeypatch):
monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', False)
subfuns, params = interpreter_mod._split_bind_params(
_DummyPrimitive((['fn'], {'x': 2})), {'x': 1}
)
assert subfuns == ['fn']
assert params == {'x': 2}


def test_split_bind_params_accepts_new_dict_only_api(monkeypatch):
monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', True)
subfuns, params = interpreter_mod._split_bind_params(
_DummyPrimitive({'x': 2}), {'x': 1}
)
assert subfuns == ()
assert params == {'x': 2}


def test_split_bind_params_rejects_invalid_tuple_shape(monkeypatch):
monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', False)
with pytest.raises(TypeError, match='2-tuple'):
interpreter_mod._split_bind_params(_DummyPrimitive(({'x': 2},)), {'x': 1})


def test_split_bind_params_rejects_dict_on_legacy_jax(monkeypatch):
monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', False)
with pytest.raises(TypeError, match='JAX < 0.9.2'):
interpreter_mod._split_bind_params(_DummyPrimitive({'x': 2}), {'x': 1})


def test_split_bind_params_rejects_tuple_on_new_jax(monkeypatch):
monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', True)
with pytest.raises(TypeError, match='JAX >= 0.9.2'):
interpreter_mod._split_bind_params(_DummyPrimitive((['fn'], {'x': 2})), {'x': 1})