Description
The `create_zeros` JIT function inside `_single_replica_deserialize_and_broadcast` requires `jax.set_mesh(local_mesh)` to scope the JIT-compiled zeros allocation to the replica-local devices. Without this mesh context, JIT compilation conflicts with the active training mesh and fails with an "incompatible devices" error on the non-primary code path.
Related issue: #3164
Environment
- JAX 0.9, multi-host GPU (4 workers × 8 H200 GPUs = 32 devices)
- Orbax 0.11.33
- Mesh shape: pp=1, dp=4, tp=8
- Single-replica broadcast checkpoint restore
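As a quick consistency check on this setup, the mesh shape accounts for all 32 devices. Assuming device ids are assigned row-major over (pp, dp, tp), each dp replica owns a contiguous block of tp=8 ids, which matches the [24, ..., 31] block reported in the error below. A minimal sketch (`replica_device_ids` is a hypothetical helper, not an Orbax or JAX API):

```python
PP, DP, TP = 1, 4, 8
NUM_WORKERS, GPUS_PER_WORKER = 4, 8

# The mesh must cover every device exactly once: 1 * 4 * 8 == 4 * 8 == 32.
assert PP * DP * TP == NUM_WORKERS * GPUS_PER_WORKER == 32


def replica_device_ids(replica, pp=PP, dp=DP, tp=TP):
    """Global device ids owned by one dp replica.

    Assumes a row-major (pp, dp, tp) layout: id = (stage * dp + replica) * tp + t.
    """
    if not 0 <= replica < dp:
        raise ValueError(f"replica must be in [0, {dp}), got {replica}")
    return [(stage * dp + replica) * tp + t for stage in range(pp) for t in range(tp)]


for r in range(DP):
    print(f"replica {r}: {replica_device_ids(r)}")
```

Under this assumed layout, replica 0 is the primary replica on devices 0-7, and replica 3 (the one on rank3 in the error) owns devices 24-31.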
Error
Without `jax.set_mesh(local_mesh)` around `create_zeros`:
```
[rank3]: ValueError: Received incompatible devices for jitted computation. Got jit's context mesh with device ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] on platform GPU and explicit output sharding with device ids [24, 25, 26, 27, 28, 29, 30, 31] on platform GPU
```
Root Cause
https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py#L1411
In `_single_replica_deserialize_and_broadcast`, the non-primary code path:
```python
@functools.partial(
    jax.jit, static_argnums=0, out_shardings=tuple(single_replica_shardings)
)
def create_zeros(shape_dtype_tup):
  return jax.tree.map(
      lambda sd: jnp.zeros(sd.shape, dtype=sd.dtype), shape_dtype_tup
  )

shape_dtype = [
    jax.ShapeDtypeStruct(arg.global_shape, arg.dtype) for arg in args
]
local_mesh = single_replica_shardings[0].mesh
with jax.set_mesh(local_mesh):
  deserialized = create_zeros(tuple(shape_dtype))
```
The `create_zeros` JIT has `out_shardings` targeting replica-local devices (e.g., devices [24–31] for a non-primary replica with tp=8). When the active training mesh spans all 32 devices, JAX rejects the compilation because the output sharding's device set does not match the context mesh's device set.
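The mismatch can be modeled in a few lines. This is a hypothetical, simplified model of the rule described above, not JAX's actual validation logic (which lives in its jit internals and is more involved); it only compares device-id sets:

```python
def devices_compatible(context_mesh_ids, out_sharding_ids):
    """Simplified rule: an explicit output sharding must use the context mesh's device set."""
    return set(context_mesh_ids) == set(out_sharding_ids)


training_mesh_ids = range(32)     # active training mesh: all 32 devices
replica_mesh_ids = range(24, 32)  # single-replica mesh of a non-primary replica

# Context mesh = training mesh, output sharding = replica mesh: the failing case.
print(devices_compatible(training_mesh_ids, replica_mesh_ids))  # False
# Context mesh overridden to the replica mesh: the sets now agree.
print(devices_compatible(replica_mesh_ids, replica_mesh_ids))   # True
```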
`jax.set_mesh(local_mesh)` is required to override the context mesh with the single-replica mesh, which makes the output shardings compatible. Without it, the call to `create_zeros` fails with the error above.
Suggested Fix
Wrap the `create_zeros` call in `jax.set_mesh(local_mesh)` to scope the JIT to the replica-local device set:
```python
local_mesh = single_replica_shardings[0].mesh
with jax.set_mesh(local_mesh):
  deserialized = create_zeros(tuple(shape_dtype))
```
Note: `jax.set_mesh(local_mesh)` is required. Using `with local_mesh:` instead does not correctly override the global mesh.
Related issue: #3164