Description
Thank you for addressing #3164! I've tested the applied fix (`with sd_mesh:` as the mesh context manager) and confirmed that the "incompatible devices" error still occurs. It appears that `with sd_mesh:` does not override the active training mesh context, so JAX still resolves JIT compilation against the full multi-device mesh.
The fix should use `with jax.set_mesh(sd_mesh):` instead of `with sd_mesh:` to properly override the active mesh context for single-device operations.
Environment
Error
[rank0]: ValueError: Received incompatible devices for jitted computation. Got argument args[0] of broadcast_in_dim with shape float32[1] and device ids [0] on platform GPU and jit's context mesh with device ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] on platform GPU
Root Cause
https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/_src/multihost/multislice.py#L253,#L268
The current code uses `with sd_mesh:` to scope single-device operations:

    sd_mesh = jax.sharding.Mesh(np.array([s.device]), ('_single',))
    with sd_mesh:
        source_device_map[s.device] = jnp.expand_dims(s.data, axis=0)

`with sd_mesh:` enters the mesh as a context manager but does not take precedence over the active training mesh. When `jnp.expand_dims` triggers JIT compilation (e.g., `broadcast_in_dim`), JAX resolves the mesh context to the training mesh (32 devices) rather than the intended single-device mesh, causing the incompatible devices error.
Fix
Replace `with sd_mesh:` with `with jax.set_mesh(sd_mesh):`:

    if is_source:
        for s in inp.addressable_shards:
            sd_mesh = jax.sharding.Mesh(np.array([s.device]), ('_single',))
            with jax.set_mesh(sd_mesh):
                source_device_map[s.device] = jnp.expand_dims(s.data, axis=0)
        ...
    else:
        slice_shape = _get_slice_shape(index, global_shape)
        sd_mesh = jax.sharding.Mesh(np.array([d]), ('_single',))
        with jax.set_mesh(sd_mesh):
            zero_data = jnp.zeros(slice_shape, dtype=inp.dtype, device=d)
        device_buffers.append(zero_data)
`jax.set_mesh` properly overrides the active mesh context for JIT compilation, while `with sd_mesh:` does not. We have verified that `with jax.set_mesh(sd_mesh):` resolves this issue in our multi-host training environment.
Questions
What is the difference between `with sd_mesh:` and `with jax.set_mesh(sd_mesh):`, and why is `with sd_mesh:` unable to override the active training mesh context?
Related