Skip to content

Thread PRNG key through wrap_jax_jit to prevent UnexpectedTracerError#95

Open
ifding wants to merge 1 commit into
google:mainfrom
ifding:fix-prng-tracer-leak-jax-jit
Open

Thread PRNG key through wrap_jax_jit to prevent UnexpectedTracerError#95
ifding wants to merge 1 commit into
google:mainfrom
ifding:fix-prng-tracer-leak-jax-jit

Conversation

@ifding
Copy link
Copy Markdown

@ifding ifding commented May 12, 2026

Random aten ops (dropout, bernoulli, randn, …) call env.get_and_rotate_prng_key() which mutates RuntimeProperty.prng to a DynamicJaxprTracer while jit-tracing. The tracer escapes the trace and contaminates env.prng_key; subsequent reads return a tracer instead of a usable key, and any later jax.jit invocation that re-traces and references the leaked tracer raises
jax.errors.UnexpectedTracerError.

This fixes the leak by threading the PRNG key through the jit boundary as the last positional input and trailing output. Random ops rotate on a scoped RuntimeProperty (popped on trace exit), the rotated key is returned as a jit output, and the env is refreshed outside jit so subsequent calls see the advanced key. The fix is transparent to callers — wrap_jax_jit's signature is unchanged.

Also adds a prng_key setter (the only mutation path was previously manual_seed) so wrap_jax_jit can refresh the env after each jit call.

Closes #17.

@google-cla
Copy link
Copy Markdown

google-cla Bot commented May 12, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Random aten ops (dropout, bernoulli, randn, …) call
env.get_and_rotate_prng_key() which mutates RuntimeProperty.prng to a
DynamicJaxprTracer while jit-tracing. The tracer escapes the trace
and contaminates env.prng_key; subsequent reads return a tracer
instead of a usable key, and any later jax.jit invocation that
re-traces and references the leaked tracer raises
jax.errors.UnexpectedTracerError.

This fixes the leak by threading the PRNG key through the jit
boundary as the last positional input and trailing output. Random
ops rotate on a scoped RuntimeProperty (popped on trace exit), the
rotated key is returned as a jit output, and the env is refreshed
outside jit so subsequent calls see the advanced key. The fix is
transparent to callers — wrap_jax_jit's signature is unchanged.

Also adds a prng_key setter (the only mutation path was previously
manual_seed) so wrap_jax_jit can refresh the env after each jit
call.

Closes google#17.

Signed-off-by: Fei Ding <feid@meta.com>
@ifding ifding force-pushed the fix-prng-tracer-leak-jax-jit branch from 5df28c9 to 82f1ee8 Compare May 12, 2026 22:52
@ifding
Copy link
Copy Markdown
Author

ifding commented May 12, 2026

Hi torchax maintainers — this is my first contribution, so the CI checks need
to be approved manually. Could someone with write access click "Approve and
run workflows"? The PR closes #17 and adds 4 regression tests covering the
fix.

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

How do RNGs work in torchax

1 participant