Skip to content

Updates for compatibility in RL tutorial with JAX v0.6.0#89

Merged
maxencefaldor merged 1 commit intoRobertTLange:mainfrom
mKabouri:main
May 8, 2025
Merged

Updates for compatibility in RL tutorial with JAX v0.6.0#89
maxencefaldor merged 1 commit intoRobertTLange:mainfrom
mKabouri:main

Conversation

@mKabouri
Copy link
Copy Markdown
Contributor

While running the example notebook examples/02_rl.ipynb, I encountered an error due to the removal of jax.tree_map in JAX v0.6.0. I replaced calls to jax.tree_map with jax.tree.map

The cell where I had the error:

from evosax.problems import BraxProblem as Problem
from evosax.problems.networks import MLP, tanh_output_fn

policy = MLP(
    layer_sizes=(32, 32, 32, 32, 8),
    output_fn=tanh_output_fn,
)

problem = Problem(
    env_name="ant",
    policy=policy,
    episode_length=1000,
    num_rollouts=16,
    use_normalize_obs=True,
)

key, subkey = jax.random.split(key)
problem_state = problem.init(key)

key, subkey = jax.random.split(key)
solution = problem.sample(subkey)

The error:

AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).

@mKabouri mKabouri changed the title Updates for compatibility with JAX v0.6.0 Updates for compatibility in RL tutorial with JAX v0.6.0 May 6, 2025
@maxencefaldor maxencefaldor merged commit 5853e15 into RobertTLange:main May 8, 2025
0 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants