Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions drjax/_src/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,14 @@ def broadcast_to_placement(
n_elements = self._placements_to_n_elements[placement]

def single_arg_broadcast(x):
unconstrained_tensor = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape))
if mesh is None:
logging.warning(
'No mesh found; defaulting to fully unconstrained broadcast and'
' *NOT* adding sharding constraints over the requested placement'
' axis %s.',
placement,
)
return unconstrained_tensor
return jnp.broadcast_to(x, (n_elements,) + x.shape)
else:
if mesh.are_all_axes_auto:
if _placement_axis_in_mesh(mesh, placement):
Expand All @@ -182,7 +181,9 @@ def single_arg_broadcast(x):
# the compiler that there are no constraints on this tensor. This
# will leave the choices in the hands of the compiler.
pspec = P(*([P.UNCONSTRAINED] * (len(arg.shape) + 1)))
return _constrain_alike_if_mesh(mesh, unconstrained_tensor, x, pspec)
return _constrain_alike_if_mesh(
mesh, jnp.broadcast_to(x, (n_elements,) + x.shape), x, pspec
)
elif mesh.are_all_axes_explicit:
input_sharding = jax.typeof(x).sharding
if _placement_axis_in_mesh(mesh, placement):
Expand All @@ -195,16 +196,16 @@ def single_arg_broadcast(x):
out_sharding = jax.sharding.NamedSharding(
input_sharding.mesh, P(None, *input_sharding.spec)
)
return jax.sharding.reshard(
unconstrained_tensor, out_shardings=out_sharding
return jnp.broadcast_to(
x, (n_elements,) + x.shape, out_sharding=out_sharding
)
else:
raise ValueError(
'Mesh axis types must all be either auto or manual, but got'
f' {mesh.axis_types}.'
)

return jax.jit(single_arg_broadcast)(arg)
return single_arg_broadcast(arg)

def normalized_broadcast_to_placement(
self,
Expand Down Expand Up @@ -290,9 +291,7 @@ def _constrain_at_placement_with_slices_like(x):
pspec = P(placement, *([P.UNCONSTRAINED] * (len(x.shape) - 1)))
return _constrain_alike_if_mesh(mesh, x, x[0], pspec)

arg = jax.tree_util.tree_map(
_constrain_at_placement_with_slices_like, arg
)
arg = jax.tree.map(_constrain_at_placement_with_slices_like, arg)
mapped_fn = jax.vmap(
# We must not have an `axis_name` argument here in order to work
# with any potential `shard_map` inside of `fn`.
Expand All @@ -305,27 +304,27 @@ def _constrain_at_placement_with_slices_like(x):
# In some cases, vmap may prevent placement sharding from propagating.
# We ensure placement sharding on the output just in case.
result = call_jaxpr(mapped_fn, arg)
return jax.tree_util.tree_map(
_constrain_at_placement_with_slices_like, result
)
return jax.tree.map(_constrain_at_placement_with_slices_like, result)
elif mesh.are_all_axes_explicit:
mapped_fn = jax.vmap(
# We must not have an `axis_name` argument here in order to work
# with any potential `shard_map` inside of `fn`. `fn` should not
# contain collectives that operate over the placement axis.
fn,
axis_name=placement,
in_axes=0,
out_axes=0,
)
result = call_jaxpr(mapped_fn, arg)
# Ensure the result is sharded along the placement axis when using
# explicit axes.
return jax.tree_util.tree_map(
lambda arr: jax.sharding.reshard(
arr,
jax.sharding.NamedSharding(
return jax.sharding.reshard(
result,
jax.tree.map(
lambda arr: jax.sharding.NamedSharding(
mesh, spec=P(placement, *jax.typeof(arr).sharding.spec[1:])
),
result,
),
result,
)
else:
raise ValueError(
Expand Down
Loading