From a30e2237f0a1077336bcf7e37b57332c497e1f2d Mon Sep 17 00:00:00 2001 From: Giuseppe Carleo <28149892+gcarleo@users.noreply.github.com> Date: Thu, 19 Mar 2026 20:38:01 +0100 Subject: [PATCH] Handle new JAX bind params API --- folx/interpreter.py | 37 ++++++++++++++++++++++++++++--- test/test_interpreter.py | 48 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 test/test_interpreter.py diff --git a/folx/interpreter.py b/folx/interpreter.py index 3cc9380..db5f03e 100644 --- a/folx/interpreter.py +++ b/folx/interpreter.py @@ -1,5 +1,6 @@ import functools import logging +import re import warnings from collections import defaultdict from typing import TYPE_CHECKING, Callable, ParamSpec, Sequence, TypeVar @@ -44,6 +45,36 @@ P = ParamSpec('P') +def _version_key(version: str) -> tuple[int, ...]: + return tuple( + int(match.group(0)) + for part in version.split('.') + if (match := re.match(r'\d+', part)) is not None + ) + + +_USES_DICT_BIND_PARAMS = _version_key(jax.__version__) >= (0, 9, 2) + + +def _split_bind_params(primitive, params): + """Normalize JAX primitive bind params across old and new JAX APIs.""" + + bind_params = primitive.get_bind_params(params) + if _USES_DICT_BIND_PARAMS: + if not isinstance(bind_params, dict): + raise TypeError( + 'Expected primitive.get_bind_params() to return a dict for ' + f'JAX >= 0.9.2, got {bind_params!r}.' + ) + return (), bind_params + if not isinstance(bind_params, tuple) or len(bind_params) != 2: + raise TypeError( + 'Expected primitive.get_bind_params() to return a 2-tuple for ' + f'JAX < 0.9.2, got {bind_params!r}.' + ) + return bind_params + + class JaxExprEnvironment: # A simple environment that keeps track of the variables # and frees them once they are no longer needed. @@ -169,7 +200,7 @@ def eval_pjit(eqn: JaxprEqn, invals): ) def eval_custom_jvp(eqn: JaxprEqn, invals): - subfuns, args = eqn.primitive.get_bind_params(eqn.params) + subfuns, args = _split_bind_params(eqn.primitive, eqn.params) fn = functools.partial(eqn.primitive.bind, *subfuns, **args) with LoggingPrefix(f'({summarize(eqn.source_info)})'): return wrap_forward_laplacian(fn)( @@ -177,7 +208,7 @@ def eval_custom_jvp(eqn: JaxprEqn, invals): ) def eval_laplacian(eqn: JaxprEqn, invals): - subfuns, params = eqn.primitive.get_bind_params(eqn.params) + subfuns, params = _split_bind_params(eqn.primitive, eqn.params) with LoggingPrefix(f'({summarize(eqn.source_info)})'): fn = get_laplacian(eqn.primitive, True) return fn( @@ -189,7 +220,7 @@ def eval_laplacian(eqn: JaxprEqn, invals): # Eval expression try: if all(not isinstance(x, FwdLaplArray) for x in invals): - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + subfuns, bind_params = _split_bind_params(eqn.primitive, eqn.params) # If non of the inputs were dependent on an FwdLaplArray, # we can just use the regular primitive. This will avoid # omnistaging. While this could cost us some memory and speed, diff --git a/test/test_interpreter.py b/test/test_interpreter.py new file mode 100644 index 0000000..ea0995a --- /dev/null +++ b/test/test_interpreter.py @@ -0,0 +1,48 @@ +import pytest + +import folx.interpreter as interpreter_mod + + +class _DummyPrimitive: + def __init__(self, result): + self.result = result + + def get_bind_params(self, params): + assert params == {'x': 1} + return self.result + + +def test_split_bind_params_accepts_legacy_tuple(monkeypatch): + monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', False) + subfuns, params = interpreter_mod._split_bind_params( + _DummyPrimitive((['fn'], {'x': 2})), {'x': 1} + ) + assert subfuns == ['fn'] + assert params == {'x': 2} + + +def test_split_bind_params_accepts_new_dict_only_api(monkeypatch): + monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', True) + subfuns, params = interpreter_mod._split_bind_params( + _DummyPrimitive({'x': 2}), {'x': 1} + ) + assert subfuns == () + assert params == {'x': 2} + + +def test_split_bind_params_rejects_invalid_tuple_shape(monkeypatch): + monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', False) + with pytest.raises(TypeError, match='2-tuple'): + interpreter_mod._split_bind_params(_DummyPrimitive(({'x': 2},)), {'x': 1}) + + +def test_split_bind_params_rejects_dict_on_legacy_jax(monkeypatch): + monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', False) + with pytest.raises(TypeError, match='JAX < 0.9.2'): + interpreter_mod._split_bind_params(_DummyPrimitive({'x': 2}), {'x': 1}) + + +def test_split_bind_params_rejects_tuple_on_new_jax(monkeypatch): + monkeypatch.setattr(interpreter_mod, '_USES_DICT_BIND_PARAMS', True) + with pytest.raises(TypeError, match='JAX >= 0.9.2'): + interpreter_mod._split_bind_params(_DummyPrimitive((['fn'], {'x': 2})), {'x': 1})