Skip to content

Handle the new JAX bind-params API#42

Open
gcarleo wants to merge 1 commit intomicrosoft:mainfrom
netket:codex/fix-get-bind-params-upstream
Open

Handle the new JAX bind-params API#42
gcarleo wants to merge 1 commit intomicrosoft:mainfrom
netket:codex/fix-get-bind-params-upstream

Conversation

@gcarleo
Copy link

@gcarleo gcarleo commented Mar 20, 2026

Summary

  • normalize primitive.get_bind_params(...) in the interpreter
  • gate the compatibility behavior on jax >= 0.9.2
  • add regression tests covering both the legacy tuple-return API and the new dict-return API

Root cause

For newer JAX versions, some primitives such as integer_pow, lt, and jit return a plain dict from primitive.get_bind_params(...) instead of (subfuns, params). folx still unpacked the result unconditionally, which caused failures like ValueError: not enough values to unpack.

Validation

  • uv run --python 3.13 --no-dev --with pytest --with pytest-xdist --with jax==0.9.2 --with jaxlib==0.9.2 pytest test/test_interpreter.py test/test_customjvp.py test/test_operator.py

This run passed locally on the clean upstream-only branch.

@gcarleo
Copy link
Author

gcarleo commented Mar 20, 2026

@n-gao this is a small fix needed with the latest jax

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant