diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py index c43d308..4ad412a 100644 --- a/drjax/_src/impls.py +++ b/drjax/_src/impls.py @@ -164,7 +164,6 @@ def broadcast_to_placement( n_elements = self._placements_to_n_elements[placement] def single_arg_broadcast(x): - unconstrained_tensor = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape)) if mesh is None: logging.warning( 'No mesh found; defaulting to fully unconstrained broadcast and' @@ -172,7 +171,7 @@ def single_arg_broadcast(x): ' axis %s.', placement, ) - return unconstrained_tensor + return jnp.broadcast_to(x, (n_elements,) + x.shape) else: if mesh.are_all_axes_auto: if _placement_axis_in_mesh(mesh, placement): @@ -182,7 +181,9 @@ def single_arg_broadcast(x): # the compiler that there are no constraints on this tensor. This # will leave the choices in the hands of the compiler. pspec = P(*([P.UNCONSTRAINED] * (len(arg.shape) + 1))) - return _constrain_alike_if_mesh(mesh, unconstrained_tensor, x, pspec) + return _constrain_alike_if_mesh( + mesh, jnp.broadcast_to(x, (n_elements,) + x.shape), x, pspec + ) elif mesh.are_all_axes_explicit: input_sharding = jax.typeof(x).sharding if _placement_axis_in_mesh(mesh, placement): @@ -195,8 +196,8 @@ def single_arg_broadcast(x): out_sharding = jax.sharding.NamedSharding( input_sharding.mesh, P(None, *input_sharding.spec) ) - return jax.sharding.reshard( - unconstrained_tensor, out_shardings=out_sharding + return jnp.broadcast_to( + x, (n_elements,) + x.shape, out_sharding=out_sharding ) else: raise ValueError( @@ -204,7 +205,7 @@ def single_arg_broadcast(x): f' {mesh.axis_types}.' ) - return jax.jit(single_arg_broadcast)(arg) + return single_arg_broadcast(arg) def normalized_broadcast_to_placement( self, @@ -290,9 +291,7 @@ def _constrain_at_placement_with_slices_like(x): pspec = P(placement, *([P.UNCONSTRAINED] * (len(x.shape) - 1))) return _constrain_alike_if_mesh(mesh, x, x[0], pspec) - arg = jax.tree_util.tree_map( - _constrain_at_placement_with_slices_like, arg - ) + arg = jax.tree.map(_constrain_at_placement_with_slices_like, arg) mapped_fn = jax.vmap( # We must not have an `axis_name` argument here in order to work # with any potential `shard_map` inside of `fn`. @@ -305,27 +304,27 @@ def _constrain_at_placement_with_slices_like(x): # In some cases, vmap may prevent placement sharding from propagating. # We ensure placement sharding on the output just in case. result = call_jaxpr(mapped_fn, arg) - return jax.tree_util.tree_map( - _constrain_at_placement_with_slices_like, result - ) + return jax.tree.map(_constrain_at_placement_with_slices_like, result) elif mesh.are_all_axes_explicit: mapped_fn = jax.vmap( + # We must not have an `axis_name` argument here in order to work + # with any potential `shard_map` inside of `fn`. `fn` should not + # contain collectives that operate over the placement axis. fn, - axis_name=placement, in_axes=0, out_axes=0, ) result = call_jaxpr(mapped_fn, arg) # Ensure the result is sharded along the placement axis when using # explicit axes. - return jax.tree_util.tree_map( - lambda arr: jax.sharding.reshard( - arr, - jax.sharding.NamedSharding( + return jax.sharding.reshard( + result, + jax.tree.map( + lambda arr: jax.sharding.NamedSharding( mesh, spec=P(placement, *jax.typeof(arr).sharding.spec[1:]) ), + result, ), - result, ) else: raise ValueError(