Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions drjax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,10 @@ def reduce_weighted_mean_impl(x, w):
)

def broadcast_impl(x, *, mesh=None):

return jax.tree_util.tree_map(
lambda x: prim_computations[f'broadcast_{placement}'](x, mesh=mesh), x
lambda arr: prim_computations[f'broadcast_{placement}'](arr, mesh=mesh),
x
)

api.broadcast = _implement_api(broadcast, broadcast_impl)
Expand All @@ -233,7 +235,6 @@ def drjax_program(
*,
placements: Mapping[str, int],
self_module,
use_abstract_mesh: bool = True,
):
"""Patches symbols into current module and call `jax.jit` on the result.

Expand Down Expand Up @@ -262,9 +263,6 @@ def drjax_program(
collectives referencing this name results in undefined behavior).
self_module: The Python module to patch the API when performing DrJAX
tracing.
use_abstract_mesh: Whether to optionally search for jax's abstract mesh when
adding drjax sharding constraints (e.g. making use of drjax compatible
with jax.set_mesh).

Returns:
A decorated function enabling the calling of the DrJAX API. Interoperable
Expand All @@ -279,7 +277,6 @@ def drjax_program(

placed_computations = impls.PlacedComputations(
placements_to_n_elements=placements,
use_abstract_mesh=use_abstract_mesh,
)
prim_computations, primdefs = primitives.register_primitives(
placements=placements
Expand Down
225 changes: 48 additions & 177 deletions drjax/_src/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import functools
import itertools

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -29,39 +28,18 @@ def drjax_program(*, placements):
return api.drjax_program(placements=placements, self_module=api)


@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
@parameterized.named_parameters(
("clients_placed", "clients"), ("XY_placed", "XY")
)
class ApiTest(absltest.TestCase):

def assertShardingEqual(self, arr, sharding):
canonical_array_sharding = jax.sharding.NamedSharding(
arr.sharding.mesh,
# Canonicalize with trailing `None`s to the rank of the input array.
# This canonicalizes across Auto and Explicit axis types, the former
# of which may not include trailing `None`s.
jax.sharding.PartitionSpec(*(
axis
for axis, _ in itertools.zip_longest(arr.sharding.spec, arr.shape)
)),
)
self.assertEqual(canonical_array_sharding, sharding)

def test_broadcast_with_placement_in_mesh(self, placement_name, axes_type):
def test_sharded_broadcast(self, placement_name):

@drjax_program(placements={placement_name: 100})
def broadcast_val(val):
return api.broadcast(val)

mesh = jax.sharding.Mesh(
np.array(jax.devices()),
axis_names=("some_axis",),
axis_types=(axes_type,),
)
mesh = jax.sharding.Mesh(np.array(jax.devices()), ("some_axis",))
arg_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("some_axis")
)
Expand All @@ -74,163 +52,79 @@ def broadcast_val(val):
# There is no clients dimension in the mesh, so we don't lay out the clients
# along that nonexistent dimension, but rather replicate them. Notice that we
# don't need to specify the sharding to DrJAX; it should be inferred by GSPMD.
expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None)
self.assertShardingEqual(
result, jax.sharding.NamedSharding(mesh, expected_result_pspec)
)

def test_broadcast_mesh_arg_without_placement(
self, placement_name, axes_type
):
mesh = jax.sharding.Mesh(
np.array(jax.devices()),
axis_names=("some_axis",),
axis_types=(axes_type,),
expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis")
self.assertEqual(
result.sharding, jax.sharding.NamedSharding(mesh, expected_result_pspec)
)

def test_broadcast_with_mesh(self, placement_name):
@drjax_program(placements={placement_name: 100})
def broadcast_val(val):
return api.broadcast(val, mesh=mesh)

arg_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("some_axis")
)
result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding))
return api.broadcast(val, mesh=None)

result = broadcast_val(jnp.ones(shape=[8, 8]))
chex.assert_trees_all_close(result, jnp.ones(shape=[100, 8, 8]))
# There is no clients dimension in the mesh, so we don't lay out the clients
# along that nonexistent dimension, but rather replicate them. Notice that we
# don't need to specify the sharding to DrJAX; it should be inferred by GSPMD.
expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None)
self.assertShardingEqual(
result, jax.sharding.NamedSharding(mesh, expected_result_pspec)
)

def test_fully_sharded_broadcast_mesh_arg(self, placement_name, axes_type):
mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape([4, 2]),
axis_names=(placement_name, "some_axis"),
axis_types=(axes_type, axes_type),
)

@drjax_program(placements={placement_name: 8})
def broadcast_val(val):
return api.broadcast(val, mesh=mesh)

arg_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("some_axis")
)

result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding))

chex.assert_trees_all_close(result, jnp.ones(shape=[8, 8, 8]))
# The result should be sharded across the placement_name axis.
expected_result_pspec = jax.sharding.PartitionSpec(
placement_name, "some_axis", None
)
self.assertShardingEqual(
result, jax.sharding.NamedSharding(mesh, expected_result_pspec)
)

def test_temperature_sensors_example(self, placement_name, axes_type):
def test_temp_sens_example(self, placement_name):
def one_if_over(threshold, value):
return jax.lax.cond(
value > threshold,
lambda: jnp.ones_like(value),
lambda: jnp.zeros_like(value),
)
return jax.lax.cond(value > threshold, lambda: 1.0, lambda: 0.0)

placement_dim = 100
mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape([4, 2]),
axis_names=(placement_name, "some_axis"),
axis_types=(axes_type, axes_type),
)
jax.set_mesh(mesh)

@drjax_program(placements={placement_name: placement_dim})
def temperature_sensors_example(threshold, values):
def temp_sens_example(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

measurements = jax.device_put(
jnp.arange(placement_dim),
jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec(placement_name)
),
)

self.assertEqual(temperature_sensors_example(24, measurements), 0.75)
measurements = jnp.arange(placement_dim)

def test_temperature_sensors_example_multiple_placement_values(
self, placement_name, axes_type
):
self.assertEqual(temp_sens_example(24, measurements), 0.75)

def test_temp_sens_example_multiple_placement_values(self, placement_name):
def one_if_over(threshold, value):
return jax.lax.cond(
value > threshold,
lambda: jnp.ones_like(value),
lambda: jnp.zeros_like(value),
)

mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape([4, 2]),
axis_names=(placement_name, "some_axis"),
axis_types=(axes_type, axes_type),
)
jax.set_mesh(mesh)
return jax.lax.cond(value > threshold, lambda: 1.0, lambda: 0.0)

@drjax_program(placements={placement_name: 100})
def temperature_sensors_example_100_clients(threshold, values):
def temp_sens_example_100_clients(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))

return api.reduce_mean(values_over)

@drjax_program(placements={placement_name: 20})
def temperature_sensors_example_20_clients(threshold, values):
@drjax_program(placements={placement_name: 10})
def temp_sens_example_10_clients(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

placement_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec(placement_name)
)
measurements_100 = jax.device_put(jnp.arange(100), placement_sharding)
measurements_20 = jax.device_put(jnp.arange(20), placement_sharding)
measurements_100 = jnp.arange(100)
measurements_10 = jnp.arange(10)

self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75)
self.assertEqual(
temperature_sensors_example_100_clients(24, measurements_100), 0.75
)
self.assertEqual(
temperature_sensors_example_20_clients(3, measurements_20),
0.8,
temp_sens_example_10_clients(3, measurements_10),
0.6,
)
# We should be able to recover the original result by flipping back to the
# original function.
self.assertEqual(
temperature_sensors_example_100_clients(24, measurements_100), 0.75
)
self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75)


class ApiErrorsTest(absltest.TestCase):

def test_multiple_placements_raises(self):
placement_name = "XY"
def test_multiple_placements_raises(self, placement_name):

with self.assertRaises(ValueError):

@drjax_program(placements={placement_name: 1, placement_name + "x": 1})
def _(values):
return api.reduce_mean(values)

def test_raises_outside_program_context(self):
def test_raises_outside_program_context(self, placement_name):
with self.assertRaises(api.OperatorUndefinedError):
api.broadcast(jnp.array(0.5))

num_clients = 10

@drjax_program(placements={"xy": num_clients})
@drjax_program(placements={placement_name: num_clients})
def test(values):
return api.reduce_mean(values)

Expand All @@ -241,9 +135,11 @@ def test(values):
with self.assertRaises(api.OperatorUndefinedError):
api.broadcast(jnp.array(0.5))

def test_broadcast_raises_type_error_within_program_context(self):
def test_broadcast_raises_type_error_within_program_context(
self, placement_name
):

@drjax_program(placements={"xy": 1})
@drjax_program(placements={placement_name: 1})
def test(*args):
return api.broadcast(*args)

Expand All @@ -252,9 +148,11 @@ def test(*args):
):
test(jnp.array(0.5), jnp.array(0.5))

def test_map_fn_raises_type_error_within_program_context(self):
def test_map_fn_raises_type_error_within_program_context(
self, placement_name
):

@drjax_program(placements={"xy": 1})
@drjax_program(placements={placement_name: 1})
def test(*args):
return api.map_fn(lambda x: x, *args)

Expand All @@ -263,9 +161,10 @@ def test(*args):
):
test(jnp.array(0.5), jnp.array(0.5))

def test_reduce_sum_raises_type_error_within_program_context(self):

@drjax_program(placements={"xy": 1})
def test_reduce_sum_raises_type_error_within_program_context(
self, placement_name
):
@drjax_program(placements={placement_name: 1})
def test(*args):
return api.reduce_sum(*args)

Expand All @@ -275,9 +174,10 @@ def test(*args):
):
test(jnp.array(0.5), jnp.array(0.5))

def test_reduce_mean_raises_type_error_within_program_context(self):

@drjax_program(placements={"xy": 1})
def test_reduce_mean_raises_type_error_within_program_context(
self, placement_name
):
@drjax_program(placements={placement_name: 1})
def test(*args):
return api.reduce_mean(*args)

Expand All @@ -287,47 +187,18 @@ def test(*args):
):
test(jnp.array(0.5), jnp.array(0.5))

def test_map_fn_error_propagates(self):

def test_map_fn_error_propagates(self, placement_name):
test_msg = "This is a test value error."
def foo(_):
raise ValueError(test_msg)

@drjax_program(placements={"clients": 1})
@drjax_program(placements={placement_name: 1})
def trigger_error(x):
return api.map_fn(foo, x)

with self.assertRaisesRegex(ValueError, test_msg):
trigger_error(jnp.asarray([0]))

def test_apis_with_mixed_mode_mesh_axes_raise_error(self):

mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape([4, 2]),
axis_names=("xy", "some_axis"),
axis_types=(jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Auto),
)
with jax.set_mesh(mesh):
with self.subTest("map"), self.assertRaisesRegex(
ValueError, "Mesh axis types must all be either auto or manual"
):

@drjax_program(placements={"xy": 1})
def test_map(x):
return api.map_fn(lambda arr: arr, x)

test_map(jnp.asarray([0]))

with self.subTest("broadcast"), self.assertRaisesRegex(
ValueError, "Mesh axis types must all be either auto or manual"
):

@drjax_program(placements={"xy": 1})
def test_broadcast(x):
return api.broadcast(x)

test_broadcast(jnp.asarray([0]))


# This allows us to test sharding behavior across multiple devices.
def setUpModule():
Expand Down
Loading
Loading