create_zeros in _single_replica_deserialize_and_broadcast needs jax.set_mesh with local replica mesh to avoid incompatible devices error #3184

@evelyn22chen

Description

The create_zeros JIT function inside _single_replica_deserialize_and_broadcast requires jax.set_mesh(local_mesh) to scope the JIT-compiled zeros allocation to replica-local devices. Without this mesh context, the JIT compilation conflicts with the active training mesh, causing an incompatible devices error on the non-primary code path.

Related issue: #3164

Environment

  • JAX 0.9, multi-host GPU (4 workers × 8 H200 GPUs = 32 devices)
  • Orbax 0.11.33
  • Mesh shape: pp=1, dp=4, tp=8
  • Single-replica broadcast checkpoint restore

Error

Without jax.set_mesh(local_mesh) around create_zeros:

[rank3]: ValueError: Received incompatible devices for jitted computation. Got jit's context mesh with device ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] on platform GPU and explicit output sharding with device ids [24, 25, 26, 27, 28, 29, 30, 31] on platform GPU

Root Cause

https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py#L1411

The non-primary code path in _single_replica_deserialize_and_broadcast reads:

@functools.partial(
    jax.jit, static_argnums=0, out_shardings=tuple(single_replica_shardings)
)
def create_zeros(shape_dtype_tup):
  return jax.tree.map(
      lambda sd: jnp.zeros(sd.shape, dtype=sd.dtype), shape_dtype_tup
  )

shape_dtype = [
    jax.ShapeDtypeStruct(arg.global_shape, arg.dtype) for arg in args
]
local_mesh = single_replica_shardings[0].mesh
with jax.set_mesh(local_mesh):
  deserialized = create_zeros(tuple(shape_dtype))

The create_zeros JIT has out_shardings targeting replica-local devices (e.g., devices [24–31] for a non-primary replica with tp=8). When the active training mesh spans all 32 devices, JAX rejects the compilation because the output sharding's device set doesn't match the context mesh's device set.
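The device-set mismatch can be illustrated with plain arithmetic. The sketch below is only a toy model: replica_device_ids and devices_compatible are hypothetical helpers, not JAX or Orbax APIs, and it assumes device ids are laid out replica-major (replica r owns ids r*tp through r*tp + tp - 1), which matches the ids in the error message above. JAX's real compatibility check is more involved than a set comparison.

```python
def replica_device_ids(replica_index, dp=4, tp=8):
    """Device ids owned by one data-parallel replica.

    Assumption: ids are assigned replica-major, so replica r owns
    the contiguous range [r * tp, r * tp + tp).
    """
    if not 0 <= replica_index < dp:
        raise ValueError(f"replica_index must be in [0, {dp}), got {replica_index}")
    start = replica_index * tp
    return list(range(start, start + tp))


def devices_compatible(context_ids, sharding_ids):
    """Toy stand-in for the check: both sides must name the same device set."""
    return set(context_ids) == set(sharding_ids)


global_ids = list(range(4 * 8))    # training mesh: dp=4, tp=8 -> ids 0..31
local_ids = replica_device_ids(3)  # last replica

print(local_ids)                                  # [24, 25, 26, 27, 28, 29, 30, 31]
print(devices_compatible(global_ids, local_ids))  # False: context mesh is the training mesh
print(devices_compatible(local_ids, local_ids))   # True: context mesh is the replica-local mesh
```

With the training mesh as context, the output sharding's eight devices are a strict subset of the 32 context devices, so the two sets differ. Scoping the context to the replica-local mesh makes them identical.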

jax.set_mesh(local_mesh) overrides the context mesh with the single-replica mesh, so the context devices match the devices named in out_shardings. Without it, create_zeros fails with the ValueError shown in the Error section.

Suggested Fix

Add jax.set_mesh(local_mesh) around the create_zeros call to scope the JIT to the replica-local device set:

local_mesh = single_replica_shardings[0].mesh
with jax.set_mesh(local_mesh):
  deserialized = create_zeros(tuple(shape_dtype))

Note: jax.set_mesh(local_mesh) is required. Entering the mesh directly (with local_mesh:) would not correctly override the global mesh.
