Skip to content

Fix jax dependency#100

Closed
TheodoreWolf wants to merge 1 commit intoRobertTLange:mainfrom
TheodoreWolf:theo/fix-dependencies
Closed

Fix jax dependency#100
TheodoreWolf wants to merge 1 commit intoRobertTLange:mainfrom
TheodoreWolf:theo/fix-dependencies

Conversation

@TheodoreWolf
Copy link
Copy Markdown
Contributor

Hello (again),

Currently your pyproject.toml claims to support jax>=0.5.0, but this is incorrect since 0.6.0, that has deprecated jax.tree_map. Downloading the current dependencies and running basically anything results in:

AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).

I changed the pyproject.toml to specify JAX<0.6.0.

@TheodoreWolf
Copy link
Copy Markdown
Contributor Author

Other solution would be to replace all instances of jax.tree_map module with tree. Let me know if you'd prefer that.

@RobertTLange
Copy link
Copy Markdown
Owner

Thanks for flagging this.

I agree with the underlying issue, but I think pinning jax<0.6.0 is not the right long-term fix for evosax. Since we already require jax>=0.5.0, I’d rather keep the project compatible with the newer tree API and use jax.tree.map / jax.tree_util.tree_map instead of narrowing the dependency range.

main already includes that migration now, and the current dependency constraint is jax>=0.5.0,<0.7, so this PR has effectively been superseded.

I’m going to close this as stale, but thanks again for catching the breakage and opening the PR. If you run into any remaining JAX 0.6+ incompatibilities, a follow-up PR with the specific failing case would be very helpful.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants