From 21c27d21e883424dc2da04581ade66cc55f8eddb Mon Sep 17 00:00:00 2001 From: DrJAX Team Date: Thu, 30 Apr 2026 12:57:28 -0700 Subject: [PATCH] Add mesh argument to drjax.broadcast and propagate it through implementation. PiperOrigin-RevId: 908329844 --- drjax/_src/api.py | 9 +- drjax/_src/api_test.py | 225 +++++++++------------------------------ drjax/_src/primitives.py | 98 +++++------------ 3 files changed, 77 insertions(+), 255 deletions(-) diff --git a/drjax/_src/api.py b/drjax/_src/api.py index 40d07de..14a3056 100644 --- a/drjax/_src/api.py +++ b/drjax/_src/api.py @@ -218,8 +218,10 @@ def reduce_weighted_mean_impl(x, w): ) def broadcast_impl(x, *, mesh=None): + return jax.tree_util.tree_map( - lambda x: prim_computations[f'broadcast_{placement}'](x, mesh=mesh), x + lambda arr: prim_computations[f'broadcast_{placement}'](arr, mesh=mesh), + x ) api.broadcast = _implement_api(broadcast, broadcast_impl) @@ -233,7 +235,6 @@ def drjax_program( *, placements: Mapping[str, int], self_module, - use_abstract_mesh: bool = True, ): """Patches symbols into current module and call `jax.jit` on the result. @@ -262,9 +263,6 @@ def drjax_program( collectives referencing this name results in undefined behavior). self_module: The Python module to patch the API when performing DrJAX tracing. - use_abstract_mesh: Whether to optionally search for jax's abstract mesh when - adding drjax sharding constraints (e.g. making use of drjax compatible - with jax.set_mesh). Returns: A decorated function enabling the calling of the DrJAX API. Interoperable @@ -279,7 +277,6 @@ def drjax_program( placed_computations = impls.PlacedComputations( placements_to_n_elements=placements, - use_abstract_mesh=use_abstract_mesh, ) prim_computations, primdefs = primitives.register_primitives( placements=placements diff --git a/drjax/_src/api_test.py b/drjax/_src/api_test.py index 3a9298a..2fbb535 100644 --- a/drjax/_src/api_test.py +++ b/drjax/_src/api_test.py @@ -13,7 +13,6 @@ # limitations under the License. import functools -import itertools from absl.testing import absltest from absl.testing import parameterized @@ -29,39 +28,18 @@ def drjax_program(*, placements): return api.drjax_program(placements=placements, self_module=api) -@parameterized.product( - placement_name=["clients", "XY"], - axes_type=[ - jax.sharding.AxisType.Auto, - jax.sharding.AxisType.Explicit, - ], +@parameterized.named_parameters( + ("clients_placed", "clients"), ("XY_placed", "XY") ) class ApiTest(absltest.TestCase): - def assertShardingEqual(self, arr, sharding): - canonical_array_sharding = jax.sharding.NamedSharding( - arr.sharding.mesh, - # Canonicalize with trailing `None`s to the rank of the input array. - # This canonicalizes across Auto and Explicit axis types, the former - # which may not include trailing `None`s. - jax.sharding.PartitionSpec(*( - axis - for axis, _ in itertools.zip_longest(arr.sharding.spec, arr.shape) - )), - ) - self.assertEqual(canonical_array_sharding, sharding) - - def test_broadcast_with_placement_in_mesh(self, placement_name, axes_type): + def test_sharded_broadcast(self, placement_name): @drjax_program(placements={placement_name: 100}) def broadcast_val(val): return api.broadcast(val) - mesh = jax.sharding.Mesh( - np.array(jax.devices()), - axis_names=("some_axis",), - axis_types=(axes_type,), - ) + mesh = jax.sharding.Mesh(np.array(jax.devices()), ("some_axis",)) arg_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec("some_axis") ) @@ -74,149 +52,65 @@ def broadcast_val(val): # No clients dimension in the mesh, we don't lay out the clients along that # nonexistent dimension, but rather replicate them. Notice that we don't # need to specify the sharding to DrJAX; it should be inferred by GSPMD. - expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None) - self.assertShardingEqual( - result, jax.sharding.NamedSharding(mesh, expected_result_pspec) - ) - - def test_broadcast_mesh_arg_without_placement( - self, placement_name, axes_type - ): - mesh = jax.sharding.Mesh( - np.array(jax.devices()), - axis_names=("some_axis",), - axis_types=(axes_type,), + expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis") + self.assertEqual( + result.sharding, jax.sharding.NamedSharding(mesh, expected_result_pspec) ) + def test_broadcast_with_mesh(self, placement_name): @drjax_program(placements={placement_name: 100}) def broadcast_val(val): - return api.broadcast(val, mesh=mesh) - - arg_sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec("some_axis") - ) - result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) + return api.broadcast(val, mesh=None) + result = broadcast_val(jnp.ones(shape=[8, 8])) chex.assert_trees_all_close(result, jnp.ones(shape=[100, 8, 8])) - # No clients dimension in the mesh, we don't lay out the clients along that - # nonexistent dimension, but rather replicate them. Notice that we don't - # need to specify the sharding to DrJAX; it should be inferred by GSPMD. - expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None) - self.assertShardingEqual( - result, jax.sharding.NamedSharding(mesh, expected_result_pspec) - ) - def test_fully_sharded_broadcast_mesh_arg(self, placement_name, axes_type): - mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape([4, 2]), - axis_names=(placement_name, "some_axis"), - axis_types=(axes_type, axes_type), - ) - - @drjax_program(placements={placement_name: 8}) - def broadcast_val(val): - return api.broadcast(val, mesh=mesh) - - arg_sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec("some_axis") - ) - - result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) - - chex.assert_trees_all_close(result, jnp.ones(shape=[8, 8, 8])) - # The result should be sharded across the placement_name axis. - expected_result_pspec = jax.sharding.PartitionSpec( - placement_name, "some_axis", None - ) - self.assertShardingEqual( - result, jax.sharding.NamedSharding(mesh, expected_result_pspec) - ) - - def test_temperature_sensors_example(self, placement_name, axes_type): + def test_temp_sens_example(self, placement_name): def one_if_over(threshold, value): - return jax.lax.cond( - value > threshold, - lambda: jnp.ones_like(value), - lambda: jnp.zeros_like(value), - ) + return jax.lax.cond(value > threshold, lambda: 1.0, lambda: 0.0) placement_dim = 100 - mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape([4, 2]), - axis_names=(placement_name, "some_axis"), - axis_types=(axes_type, axes_type), - ) - jax.set_mesh(mesh) @drjax_program(placements={placement_name: placement_dim}) - def temperature_sensors_example(threshold, values): + def temp_sens_example(threshold, values): threshold_at_clients = api.broadcast(threshold) values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) return api.reduce_mean(values_over) - measurements = jax.device_put( - jnp.arange(placement_dim), - jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(placement_name) - ), - ) - - self.assertEqual(temperature_sensors_example(24, measurements), 0.75) + measurements = jnp.arange(placement_dim) - def test_temperature_sensors_example_multiple_placement_values( - self, placement_name, axes_type - ): + self.assertEqual(temp_sens_example(24, measurements), 0.75) + def test_temp_sens_example_multiple_placement_values(self, placement_name): def one_if_over(threshold, value): - return jax.lax.cond( - value > threshold, - lambda: jnp.ones_like(value), - lambda: jnp.zeros_like(value), - ) - - mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape([4, 2]), - axis_names=(placement_name, "some_axis"), - axis_types=(axes_type, axes_type), - ) - jax.set_mesh(mesh) + return jax.lax.cond(value > threshold, lambda: 1.0, lambda: 0.0) @drjax_program(placements={placement_name: 100}) - def temperature_sensors_example_100_clients(threshold, values): + def temp_sens_example_100_clients(threshold, values): threshold_at_clients = api.broadcast(threshold) values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) + return api.reduce_mean(values_over) - @drjax_program(placements={placement_name: 20}) - def temperature_sensors_example_20_clients(threshold, values): + @drjax_program(placements={placement_name: 10}) + def temp_sens_example_10_clients(threshold, values): threshold_at_clients = api.broadcast(threshold) values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) return api.reduce_mean(values_over) - placement_sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(placement_name) - ) - measurements_100 = jax.device_put(jnp.arange(100), placement_sharding) - measurements_20 = jax.device_put(jnp.arange(20), placement_sharding) + measurements_100 = jnp.arange(100) + measurements_10 = jnp.arange(10) + self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75) self.assertEqual( - temperature_sensors_example_100_clients(24, measurements_100), 0.75 - ) - self.assertEqual( - temperature_sensors_example_20_clients(3, measurements_20), - 0.8, + temp_sens_example_10_clients(3, measurements_10), + 0.6, ) # We should be able to recover the original result flipping back to the # original function. - self.assertEqual( - temperature_sensors_example_100_clients(24, measurements_100), 0.75 - ) + self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75) - -class ApiErrorsTest(absltest.TestCase): - - def test_multiple_placements_raises(self): - placement_name = "XY" + def test_multiple_placements_raises(self, placement_name): with self.assertRaises(ValueError): @@ -224,13 +118,13 @@ def test_multiple_placements_raises(self): def _(values): return api.reduce_mean(values) - def test_raises_outside_program_context(self): + def test_raises_outside_program_context(self, placement_name): with self.assertRaises(api.OperatorUndefinedError): api.broadcast(jnp.array(0.5)) num_clients = 10 - @drjax_program(placements={"xy": num_clients}) + @drjax_program(placements={placement_name: num_clients}) def test(values): return api.reduce_mean(values) @@ -241,9 +135,11 @@ def test(values): with self.assertRaises(api.OperatorUndefinedError): api.broadcast(jnp.array(0.5)) - def test_broadcast_raises_type_error_within_program_context(self): + def test_broadcast_raises_type_error_within_program_context( + self, placement_name + ): - @drjax_program(placements={"xy": 1}) + @drjax_program(placements={placement_name: 1}) def test(*args): return api.broadcast(*args) @@ -252,9 +148,11 @@ def test(*args): ): test(jnp.array(0.5), jnp.array(0.5)) - def test_map_fn_raises_type_error_within_program_context(self): + def test_map_fn_raises_type_error_within_program_context( + self, placement_name + ): - @drjax_program(placements={"xy": 1}) + @drjax_program(placements={placement_name: 1}) def test(*args): return api.map_fn(lambda x: x, *args) @@ -263,9 +161,10 @@ def test(*args): ): test(jnp.array(0.5), jnp.array(0.5)) - def test_reduce_sum_raises_type_error_within_program_context(self): - - @drjax_program(placements={"xy": 1}) + def test_reduce_sum_raises_type_error_within_program_context( + self, placement_name + ): + @drjax_program(placements={placement_name: 1}) def test(*args): return api.reduce_sum(*args) @@ -275,9 +174,10 @@ def test(*args): ): test(jnp.array(0.5), jnp.array(0.5)) - def test_reduce_mean_raises_type_error_within_program_context(self): - - @drjax_program(placements={"xy": 1}) + def test_reduce_mean_raises_type_error_within_program_context( + self, placement_name + ): + @drjax_program(placements={placement_name: 1}) def test(*args): return api.reduce_mean(*args) @@ -287,47 +187,18 @@ def test(*args): ): test(jnp.array(0.5), jnp.array(0.5)) - def test_map_fn_error_propagates(self): - + def test_map_fn_error_propagates(self, placement_name): test_msg = "This is a test value error." def foo(_): raise ValueError(test_msg) - @drjax_program(placements={"clients": 1}) + @drjax_program(placements={placement_name: 1}) def trigger_error(x): return api.map_fn(foo, x) with self.assertRaisesRegex(ValueError, test_msg): trigger_error(jnp.asarray([0])) - def test_apis_with_mixed_mode_mesh_axes_raise_error(self): - - mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape([4, 2]), - axis_names=("xy", "some_axis"), - axis_types=(jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Auto), - ) - with jax.set_mesh(mesh): - with self.subTest("map"), self.assertRaisesRegex( - ValueError, "Mesh axis types must all be either auto or manual" - ): - - @drjax_program(placements={"xy": 1}) - def test_map(x): - return api.map_fn(lambda arr: arr, x) - - test_map(jnp.asarray([0])) - - with self.subTest("broadcast"), self.assertRaisesRegex( - ValueError, "Mesh axis types must all be either auto or manual" - ): - - @drjax_program(placements={"xy": 1}) - def test_broadcast(x): - return api.broadcast(x) - - test_broadcast(jnp.asarray([0])) - # This allows us to test sharding behavior across multiple devices. def setUpModule(): diff --git a/drjax/_src/primitives.py b/drjax/_src/primitives.py index b4bab3a..f46b646 100644 --- a/drjax/_src/primitives.py +++ b/drjax/_src/primitives.py @@ -31,7 +31,6 @@ class BroadcastType(Protocol): def __call__( self, x: jnp.ndarray, - mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None, ) -> jnp.ndarray: ... @@ -45,8 +44,8 @@ def _define_broadcast_prim( """Defines and returns broadcast ptimitive and associated binding.""" broadcast_p = extended_core.Primitive(broadcast_name) # Create the primitive - def broadcast_prim_fn(x, *, mesh=None): - return broadcast_p.bind(x, mesh=mesh) + def broadcast_prim_fn(x, **params): + return broadcast_p.bind(x, **params) return (broadcast_p, broadcast_prim_fn) @@ -56,7 +55,6 @@ def _register_broadcast_impls( broadcast_prim_fn: BroadcastType, broadcast_array_eval: BroadcastType, sum_prim_fn: AggType, - placement_str: str, n_elements: int, ) -> None: """Registers implementations for the broadcast primitive. @@ -76,37 +74,12 @@ def _register_broadcast_impls( sum_prim_fn: A callable which binds its arguments to the summation primitive from the placement inserted by this broadcast. Similar to `broadcast_prim_fn`. - placement_str: The name of the placement which this broadcast targets. n_elements: The number of elements present at the placement which this broadcast targets. """ - def broadcast_abstract_eval( - xs, *, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None - ) -> core.ShapedArray: - # If no mesh was provided, we try to use the current abstract mesh. - if mesh is None: - abstract_mesh = jax.sharding.get_abstract_mesh() - else: - abstract_mesh = ( - mesh.abstract_mesh if isinstance(mesh, jax.sharding.Mesh) else mesh - ) - sharding_axis = ( - placement_str - if impls._placement_axis_in_mesh(abstract_mesh, placement_str) # pylint: disable=protected-access - else None - ) - new_sharding = xs.sharding.update( - mesh=abstract_mesh, - spec=jax.sharding.PartitionSpec(sharding_axis, *xs.sharding.spec), - ) - return core.ShapedArray( - shape=(n_elements,) + xs.shape, - dtype=xs.dtype, - weak_type=xs.weak_type, - sharding=new_sharding, - memory_space=xs.memory_space, - ) + def broadcast_abstract_eval(xs, **_params): + return core.ShapedArray((n_elements,) + xs.shape, xs.dtype) # Abstract eval rule. broadcast_p.def_abstract_eval(broadcast_abstract_eval) @@ -114,32 +87,31 @@ def broadcast_abstract_eval( broadcast_p.def_impl(broadcast_array_eval) # Lowering rule to MLIR. mlir.register_lowering( - broadcast_p, mlir.lower_fun(broadcast_array_eval, multiple_results=False), + broadcast_p, mlir.lower_fun(broadcast_array_eval, multiple_results=False) ) - def broadcast_jvp(primals_in, tangents_in, mesh): - primals_out = broadcast_prim_fn(*primals_in, mesh=mesh) - tangents_out = broadcast_prim_fn(*tangents_in, mesh=mesh) + def broadcast_jvp(primals_in, tangents_in, **params): + primals_out = broadcast_prim_fn(*primals_in, **params) + tangents_out = broadcast_prim_fn(*tangents_in, **params) return primals_out, tangents_out # Registering JVP should allow forward AD. ad.primitive_jvps[broadcast_p] = broadcast_jvp - def broadcast_vjp(cotangents_out, primals_in, mesh): - del mesh # Unused. + def broadcast_vjp(cotangents_out, primals_in, **params): if isinstance(cotangents_out, jax.interpreters.ad.Zero): # We are differerentiating back through a broadcast; the incoming value, # therefore, has the right shape and dtype for the Zero we generate. - return (jax.interpreters.ad.Zero(primals_in.aval.to_ct_aval()),) + return (jax.interpreters.ad.Zero(primals_in.aval),) # This implementation *must* use the sum_prim_fn, rather than the array # implementation of summation, to result in a reduce_sum in the Jaxpr. - return (sum_prim_fn(cotangents_out),) + return (sum_prim_fn(cotangents_out, **params),) ad.primitive_transposes[broadcast_p] = broadcast_vjp - def _batch_broadcast(xs, batched_shape, mesh): + def _batch_broadcast(xs, batched_shape, **params): # We inserted clients dimension in front, so batch dim went down one. - return broadcast_prim_fn(*xs, mesh=mesh), batched_shape[0] + 1 + return broadcast_prim_fn(*xs, **params), batched_shape[0] + 1 # Make sure this can also be batched / mapped. This happens when dispatching # forward AD, I think. @@ -152,8 +124,8 @@ def _define_single_arg_agg_prim( """Defines and returns an aggregation primitive taking a single argument.""" agg_p = extended_core.Primitive(agg_name) # Create the primitive - def agg_prim_fn(x): - return agg_p.bind(x) + def agg_prim_fn(x, **params): + return agg_p.bind(x, **params) return agg_p, agg_prim_fn @@ -181,39 +153,21 @@ def _register_single_arg_agg_impls( aggregation primitive. """ - def agg_abstract_eval(xs) -> core.ShapedArray: - - def aval_with_new_sharding(x): - # We slice away the first dimension in doing the reduction; its gone! - new_sharding = x.sharding.update( - spec=jax.sharding.PartitionSpec(*x.sharding.spec[1:]) - ) - return core.ShapedArray( - shape=x.shape[1:], - dtype=x.dtype, - weak_type=x.weak_type, - sharding=new_sharding, - memory_space=x.memory_space, - ) - - return jax.tree.map(aval_with_new_sharding, xs) + def agg_abstract_eval(xs): + return jax.tree_util.tree_map( + # We slice away the first dimension in doing the reduction; its gone! + lambda x: core.ShapedArray(x.shape[1:], x.dtype), + xs, + ) # Abstract eval rule agg_p.def_abstract_eval(agg_abstract_eval) # Concrete eval rule agg_p.def_impl(agg_array_eval) # Lowering rule to MLIR. - kwargs = {} - # TODO(krush): The mean primitive is buggy: if passed an integer input, its - # abstract eval rule will claim to return an integer output, but its eval rule - # will return a float output. Fix this and remove the cacheable=False, which - # works around the bug. - if jax.version.__version_info__ >= (0, 7): - kwargs['cacheable'] = False mlir.register_lowering( agg_p, mlir.lower_fun(agg_array_eval, multiple_results=False), - **kwargs ) def agg_jvp(primals_in, tangents_in): @@ -229,7 +183,7 @@ def agg_vjp(cotangents_out, primals_in): # generate. This is always correct if jax's symbolic Zero is a static # concept, depending on data flow in the program (rather than e.g. runtime # values). - return (jax.interpreters.ad.Zero(primals_in.aval.to_ct_aval()),) + return (jax.interpreters.ad.Zero(primals_in.aval),) return (vjp_impl(cotangents_out),) ad.primitive_transposes[agg_p] = agg_vjp @@ -238,7 +192,7 @@ def _batch_agg(xs, batched_shape): # Certain jax libs can silently insert the 'batching' dim 'all the way at # the front'; we are about to destroy the front axis by agging, so move # that puppy to the back. Tell the rest of JAX what happened here. - xs = jnp.moveaxis(*xs, *batched_shape, -1) + xs = batching.moveaxis(*xs, *batched_shape, -1) return agg_prim_fn(xs), len(xs.shape) - 2 # Make sure this can also be batched / mapped. This happens when dispatching @@ -284,15 +238,15 @@ def _define_and_register_prims_for_placement( primdef_dict[sum_name] = sum_p primdef_dict[mean_name] = mean_p - def broadcast_array_eval(x, *, mesh): - return impl_defs.broadcast_to_placement(x, placement_str, mesh) + def broadcast_array_eval(x, **params): + mesh = params.get('mesh') + return impl_defs.broadcast_to_placement(x, placement_str, mesh=mesh) _register_broadcast_impls( broadcast_p, broadcast_prim_fn, broadcast_array_eval, sum_prim_fn, - placement_str, n_elements, )