Description
broadcast_one_replica_to_all deadlocks during single-replica checkpoint restore when:
- The primary process is still performing slow I/O (e.g., S3 deserialization), and
- Non-primary processes race ahead into _merge_globalized_replicas, triggering distributed JIT compilation that blocks waiting for all processes.
The first call to _merge_globalized_replicas inside broadcast_one_replica_to_all triggers distributed compilation of the jnp.sum all-reduce. If the primary hasn't arrived yet, compilation deadlocks because XLA requires all processes to enter the same jit call together.
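For illustration, here is a minimal sketch (my own, not Orbax code) of the constraint I believe is at play: a jit'ed reduction over a globally sharded array only makes progress once every process executes it. The 4×8 device layout and axis names are taken from my environment above, and jax.distributed.initialize() is assumed to have already run on every host.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: jax.distributed.initialize() has run on every host, so
# jax.devices() spans all 32 devices (4 replicas x 8-way TP).
mesh = Mesh(np.array(jax.devices()).reshape(4, 8), ("replica", "tp"))
sharding = NamedSharding(mesh, P("replica", "tp"))

# A globally sharded array: each process only materializes its addressable shards.
global_data = np.ones((4, 8), dtype=np.float32)
x = jax.make_array_from_callback((4, 8), sharding, lambda idx: global_data[idx])

@jax.jit
def merge(a):
    # Cross-replica reduction, analogous to the jnp.sum all-reduce in
    # _merge_globalized_replicas.
    return jnp.sum(a, axis=0)

# Every process must reach this point and run the same jit'ed computation;
# a process still busy elsewhere (e.g. S3 I/O) stalls all the others here.
out = merge(x)
jax.block_until_ready(out)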
Environment
- JAX 0.9, multi-host GPU (4 workers × 8 H200 GPUs = 32 devices)
- Orbax 0.11.33
- S3 checkpointing
- Mesh shape with TP=8 (pp=1, dp=4, tp=8)
- Single-replica broadcast checkpoint restore
Error
No explicit error — the job hangs indefinitely. Only one non-primary process reaches _merge_globalized_replicas and blocks in XLA compilation waiting for the primary, which is still performing S3 I/O in _deserialize_arrays.
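For completeness, one way to confirm where each process is blocked is a standard-library faulthandler stack dump (nothing Orbax-specific; the signal choice is arbitrary):

import faulthandler
import signal

# Register near process start; sending SIGUSR1 to a hung worker then dumps every
# thread's Python stack, making it visible when the primary is still inside
# _deserialize_arrays while a non-primary sits in the jit'ed merge.
faulthandler.register(signal.SIGUSR1, all_threads=True)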
Root Cause
https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/_src/multihost/multislice.py#L369
In broadcast_one_replica_to_all:
globalized_sharded_subtree = jax.tree.map(
    functools.partial(
        _globalize_single_replica_arrays,
        global_mesh=global_mesh,
        replica_axis_index=replica_axis_index,
        is_source=is_source,
    ),
    subtree,
)
# Delete immediately to conserve memory.
jax.tree.map(lambda x: x.delete(), subtree)
out_subtree = _merge_globalized_replicas(  # ← jax.jit: compilation on first call; subsequent calls use the compilation cache
    globalized_sharded_subtree, global_mesh
)
out_tree.extend(out_subtree)
jax.block_until_ready(out_subtree)
start = end
The first call to _merge_globalized_replicas triggers distributed JIT compilation (blocks the host thread until all processes present matching HLO). Non-primaries reach this point in milliseconds (they only allocate zeros), while the primary is still performing S3 deserialization. This creates a deadlock.
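To make the timing asymmetry concrete, a rough sketch with illustrative stand-in functions (assumed names, not Orbax APIs):

import time
import jax
import jax.numpy as jnp

# Stand-ins for the two restore paths (illustrative only).
def deserialize_from_s3():   # primary: network-bound, seconds to minutes
    time.sleep(120)
    return [jnp.ones((1024, 1024))]

def allocate_zeros():        # non-primaries: device allocation, milliseconds
    return [jnp.zeros((1024, 1024))]

is_primary = jax.process_index() == 0
deserialized = deserialize_from_s3() if is_primary else allocate_zeros()

# Non-primaries now reach the jit'ed merge inside broadcast_one_replica_to_all
# almost immediately, while the primary is still inside deserialize_from_s3();
# the first distributed compilation stalls until every process arrives.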
Suggested Fix
Add a multihost.sync_global_processes barrier between the asymmetric deserialization/zeros-allocation code path and the call to broadcast_one_replica_to_all:
orbax/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py, line 1416 (commit c8be6df):
shared_state, _ = multislice.broadcast_one_replica_to_all(
Inside _single_replica_deserialize_and_broadcast:
deserialized = tuple(deserialized)
global_mesh = cast(jax.sharding.NamedSharding, shardings[0]).mesh
# Barrier: ensure all processes enter broadcast together.
# Without this, non-primaries race ahead into _merge_globalized_replicas
# (a multi-process JIT) while the primary is still deserializing, causing deadlock.
multihost.sync_global_processes('single_replica_pre_broadcast') # [This is the proposed fix]
shared_state, _ = multislice.broadcast_one_replica_to_all(
    deserialized,
    global_mesh,
    replica_axis_index,
    is_primary,
    memory_limit_bytes=broadcast_memory_limit_bytes,
    memory_scaling_factor=broadcast_memory_scaling_factor,
)
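For reference, the same barrier can also be expressed with JAX's public multihost utility, if that is preferred over Orbax's own multihost.sync_global_processes wrapper (sketch only):

from jax.experimental import multihost_utils

# Barrier across all processes: each one blocks until every process has
# reached this named sync point.
multihost_utils.sync_global_devices('single_replica_pre_broadcast')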
Questions
- Is this a known issue?
- Is my understanding correct that distributed JIT compilation requires all processes to arrive at the jit call simultaneously?
- Is adding a barrier the recommended fix here, or is there a preferred alternative?