Skip to content

fix: remove retired jax_enable_async_collective_offload flag from TPU launch script#119

Open
aryanputta wants to merge 1 commit into
young-geng:mainfrom
aryanputta:fix/remove-retired-jax-async-collective-offload-flag
Open

fix: remove retired jax_enable_async_collective_offload flag from TPU launch script#119
aryanputta wants to merge 1 commit into
young-geng:mainfrom
aryanputta:fix/remove-retired-jax-async-collective-offload-flag

Conversation

@aryanputta

Copy link
Copy Markdown

Problem

Running examples/pretrain_llama_7b.sh on TPU pods with recent JAX versions fails immediately with:

ERROR: Accessing retired flag 'jax_enable_async_collective_offload'

The flag --jax_enable_async_collective_offload=true in LIBTPU_INIT_ARGS was retired upstream. JAX now manages async collective offload behavior internally and does not accept this flag as an explicit argument.

Reported in #109.

Fix

Remove --jax_enable_async_collective_offload=true from LIBTPU_INIT_ARGS in examples/pretrain_llama_7b.sh. All other XLA/TPU flags are retained unchanged.

Test

Reproduced the error on a v3 TPU pod with JAX >= 0.4.26. After this change the ERROR: Accessing retired flag message is gone and training proceeds normally.

… launch script

The --jax_enable_async_collective_offload flag was retired in recent JAX
releases. Accessing it triggers:

  ERROR: Accessing retired flag 'jax_enable_async_collective_offload'

which aborts TPU pod training before the first step. This is the
root cause reported in issue young-geng#109.

Remove the flag from LIBTPU_INIT_ARGS in examples/pretrain_llama_7b.sh.
The async collective offload behaviour it controlled is now managed
internally by the XLA/JAX runtime and does not need explicit opt-in.

Fixes young-geng#109
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant