
broadcast_one_replica_to_all deadlocks on first resume without barrier before distributed JIT compilation #3181

@evelyn22chen

Description


broadcast_one_replica_to_all deadlocks during single-replica checkpoint restore when:

  1. The primary process is still performing slow I/O (e.g., S3 deserialization), and
  2. Non-primary processes race ahead into _merge_globalized_replicas, triggering distributed JIT compilation that blocks waiting for all processes.

The first call to _merge_globalized_replicas inside broadcast_one_replica_to_all triggers distributed compilation of the jnp.sum all-reduce. If the primary hasn't arrived yet, compilation deadlocks because XLA requires all processes to enter the same jit call together.
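
For illustration, the same race can be reproduced outside Orbax with a few lines of JAX. Everything below is a hypothetical standalone sketch, not Orbax code: the mesh axis name, array shape, merge function, and the sleep standing in for S3 I/O are all made up for the example.

    import time
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    jax.distributed.initialize()  # multi-host setup assumed (coordinator supplied by the launcher)

    # 1-D mesh over all devices; the axis name is arbitrary for this sketch.
    mesh = Mesh(jax.devices(), axis_names=('replica',))
    sharding = NamedSharding(mesh, P('replica'))

    @jax.jit
    def merge(x):
        # Stand-in for the jnp.sum all-reduce inside _merge_globalized_replicas.
        return jnp.sum(x, axis=0)

    # Build a global array spanning every process.
    global_shape = (jax.device_count(), 1024)
    x = jax.make_array_from_callback(
        global_shape, sharding, lambda idx: jnp.zeros(global_shape)[idx])

    if jax.process_index() == 0:
        time.sleep(600)  # stands in for slow S3 deserialization on the primary

    # Non-primary processes reach the first (compiling) call immediately and
    # block here; the job hangs until process 0 also issues the same call.
    out = merge(x)
    jax.block_until_ready(out)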

Environment

  • JAX 0.9, multi-host GPU (4 workers × 8 H200 GPUs = 32 devices)
  • Orbax 0.11.33
  • S3 checkpointing
  • Mesh shape (pp=1, dp=4, tp=8)
  • Single-replica broadcast checkpoint restore

Error

No explicit error — the job hangs indefinitely. Only one non-primary process reaches _merge_globalized_replicas and blocks in XLA compilation waiting for the primary, which is still performing S3 I/O in _deserialize_arrays.
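
For anyone hitting the same hang, a generic way to confirm where each process is parked (standard-library diagnostics, nothing Orbax-specific) is to arm a periodic stack dump around the restore:

    import faulthandler
    import sys

    # Dump every thread's Python stack to stderr every 120 s until cancelled; on a
    # hung restore this shows the primary inside _deserialize_arrays and the
    # non-primaries blocked in the jitted _merge_globalized_replicas call.
    faulthandler.dump_traceback_later(120, repeat=True, file=sys.stderr)
    # ... run the checkpoint restore here ...
    faulthandler.cancel_dump_traceback_later()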

Root Cause

https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/_src/multihost/multislice.py#L369

In broadcast_one_replica_to_all:

    globalized_sharded_subtree = jax.tree.map(
        functools.partial(
            _globalize_single_replica_arrays,
            global_mesh=global_mesh,
            replica_axis_index=replica_axis_index,
            is_source=is_source,
        ),
        subtree,
    )
    # Delete immediately to conserve memory.
    jax.tree.map(lambda x: x.delete(), subtree)
    out_subtree = _merge_globalized_replicas(  # ← jax.jit: compilation on first call; subsequent calls use compilation cache
        globalized_sharded_subtree, global_mesh
    )
    out_tree.extend(out_subtree)
    jax.block_until_ready(out_subtree)
    start = end

The first call to _merge_globalized_replicas triggers distributed JIT compilation, which blocks the host thread until every process has submitted the same computation. Non-primaries reach this point in milliseconds (they only allocate zeros), while the primary is still performing S3 deserialization, so the job deadlocks.
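
For reference, the blocking call is an ordinary jit-compiled all-reduce. A minimal paraphrase of the pattern, based on the jnp.sum description above (hypothetical name and axis choice, not the actual Orbax implementation):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def merge_replicas_sketch(tree):
        # Non-source replicas were globalized as zeros, so summing over the
        # leading (replica) axis reconstructs the checkpointed values on every
        # replica. The first call compiles this collective, and that step blocks
        # until every process has issued the same call.
        return jax.tree.map(lambda x: jnp.sum(x, axis=0), tree)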

Suggested Fix

Add a multihost.sync_global_processes barrier between the asymmetric deserialization/zeros-allocation code path and the call to broadcast_one_replica_to_all, inside _single_replica_deserialize_and_broadcast:

    deserialized = tuple(deserialized)
    global_mesh = cast(jax.sharding.NamedSharding, shardings[0]).mesh

    # Barrier: ensure all processes enter broadcast together.
    # Without this, non-primaries race ahead into _merge_globalized_replicas
    # (a multi-process JIT) while the primary is still deserializing, causing deadlock.
    multihost.sync_global_processes('single_replica_pre_broadcast')  # [This is the proposed fix]

    shared_state, _ = multislice.broadcast_one_replica_to_all(
        deserialized,
        global_mesh,
        replica_axis_index,
        is_primary,
        memory_limit_bytes=broadcast_memory_limit_bytes,
        memory_scaling_factor=broadcast_memory_scaling_factor,
    )
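
As a note on the trade-off: the barrier costs one cross-process round trip per restore and, because the broadcast cannot complete without the primary's data anyway, it should not lengthen the critical path; it only makes explicit the assumption broadcast_one_replica_to_all already relies on, namely that every process reaches the collective together.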

Questions

  • Is this a known issue?
  • Is my understanding correct that distributed JIT compilation requires all processes to arrive at the jit call simultaneously?
  • Is adding a barrier the recommended fix here, or is there a preferred alternative?
