diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml index ce1fc997d..ac7da6a94 100644 --- a/.github/workflows/cpu-tests.yml +++ b/.github/workflows/cpu-tests.yml @@ -63,6 +63,10 @@ jobs: run: | python -m pytest tests/cli/utils/ -v --tb=short + - name: Run shared mesh and topology tests + run: | + python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short + - name: Run perf tests run: | python -m pytest tests/perf/ -v --tb=short diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 73694881a..f1fa7561d 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -120,7 +120,7 @@ jobs: - name: Run tunix tests not covered by the above categories run: | # This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here. - python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$? + python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$? if [ "${code:-0}" = "5" ]; then echo "No tests collected (expected)." exit 0 diff --git a/tests/cli/config_test.py b/tests/cli/config_test.py index 575e3136e..086d21f8c 100644 --- a/tests/cli/config_test.py +++ b/tests/cli/config_test.py @@ -27,6 +27,9 @@ from tunix.sft import peft_trainer from tunix.tests import test_common as tc from tunix.utils import env_utils +from tunix.utils import mesh as mesh_lib + +os.environ.setdefault("HF_TOKEN", "TestToken") class ConfigTest(parameterized.TestCase): @@ -262,7 +265,7 @@ def test_learning_rate_schedule_valid(self, overrides): self.assertIsNotNone(lr_schedule) self.assertTrue(callable(lr_schedule), "lr_schedule should be callable") - # --- Tests for create_mesh --- + # --- Tests for mesh config parsing and mesh creation --- @parameterized.named_parameters( dict( testcase_name="valid_1d", @@ -311,40 +314,48 @@ def test_create_mesh_valid( ): mock_device_count_fn.return_value = mock_num_devices hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) - mesh = hp.create_mesh("model_config") - self.assertEqual( - mesh, - jax.make_mesh( - expected[0], - expected[1], - axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]), - ), - ) + axis_shapes, axis_names = hp._parse_mesh_config("model_config") + expected_mesh = object() - def test_create_mesh_with_assigned_devices(self): - raw_keys = { - "model_config": { - "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} - } - } - hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) - assigned_devices = ["d0", "d1", "d2", "d3"] + with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock: + mesh = mesh_lib.create_mesh(axis_shapes, axis_names) - class FakeMesh: + make_mesh_mock.assert_called_once_with( + expected[0], + expected[1], + axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]), + ) + self.assertIs(mesh, expected_mesh) + + def test_create_mesh_with_assigned_devices(self): + raw_keys = { + "model_config": { + "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} + } + } + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) + axis_shapes, axis_names = hp._parse_mesh_config("model_config") + assigned_devices = ["d0", "d1", "d2", "d3"] - def __init__(self, devices, axis_names, axis_types=None): - self.devices = devices - self.axis_names = axis_names - self.axis_types = axis_types + class FakeMesh: - with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): - mesh = hp.create_mesh("model_config", devices=assigned_devices) + def __init__(self, devices, axis_names, axis_types=None): + self.devices = devices + self.axis_names = axis_names + self.axis_types = axis_types - self.assertEqual(mesh.devices.shape, (2, 2)) - self.assertSequenceEqual( - mesh.devices.flatten().tolist(), assigned_devices + with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): + mesh = mesh_lib.create_mesh( + axis_shapes, + axis_names, + devices=assigned_devices, ) - self.assertEqual(mesh.axis_names, ("x", "y")) + + self.assertEqual(mesh.devices.shape, (2, 2)) + self.assertSequenceEqual( + mesh.devices.flatten().tolist(), assigned_devices + ) + self.assertEqual(mesh.axis_names, ("x", "y")) @parameterized.named_parameters( dict( @@ -424,11 +435,12 @@ def test_create_mesh_invalid( mock_num_devices, error_regex, ): - mock_device_count_fn.return_value = mock_num_devices - with self.assertRaisesRegex(ValueError, error_regex): - nested_dict = self.convert_nested_dict_to_list(raw_keys) - hp = self.initialize_config(nested_dict) - hp.create_mesh("model_config") + mock_device_count_fn.return_value = mock_num_devices + with self.assertRaisesRegex(ValueError, error_regex): + nested_dict = self.convert_nested_dict_to_list(raw_keys) + hp = self.initialize_config(nested_dict) + axis_shapes, axis_names = hp._parse_mesh_config("model_config") + mesh_lib.create_mesh(axis_shapes, axis_names) @parameterized.named_parameters( dict( diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py index f8a8192eb..d2cadc6d1 100644 --- a/tests/cli/grpo_main_test.py +++ b/tests/cli/grpo_main_test.py @@ -753,6 +753,107 @@ def __init__(self, devices, axis_names, axis_types=None): role_to_mesh[rl_cluster_lib.Role.ACTOR], ) + def test_split_mesh_delegates_device_allocation_to_mesh_utils(self): + extra = """ +training_mode: "agentic_grpo" +data_module: "tunix.cli.recipes.deepscaler_data" +apply_chat_template_to_dataset: false +data_config: + train_data_path: "gs://fake/train.json" + eval_data_path: "gs://fake/eval.parquet" +prompt_key: "prompts" +reward_functions: [] +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_grpo_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 + context_ratio: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + actor_model_config = pipeline.config["actor_model_config"] + if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig): + actor_model_config["mesh"] = { + "shape": "(1,2)", + "axis_names": "('fsdp','tp')", + } + pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"} + rollout_model_config = pipeline.config["rollout_model_config"] + if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig): + rollout_model_config["mesh"] = { + "shape": "(1,2)", + "axis_names": "('fsdp','tp')", + } + + fake_devices = ["a0", "a1", "r0", "r1"] + allocated_devices = { + "actor_model_config": ["a0", "a1"], + "rollout_model_config": ["r0", "r1"], + } + created_mesh_devices = {} + + def fake_create_mesh(axis_shapes, axis_names, devices=None): + model_key = ( + "actor_model_config" + if axis_shapes == (1, 2) and axis_names == ("fsdp", "tp") + and "actor_model_config" not in created_mesh_devices + else "rollout_model_config" + ) + created_mesh_devices[model_key] = list(devices) + return (model_key, tuple(devices)) + + with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices): + with mock.patch.object( + grpo_main.mesh_lib, + "allocate_named_mesh_device_slices", + return_value=allocated_devices, + ) as allocate_mock: + with mock.patch.object( + grpo_main.mesh_lib, + "create_mesh", + side_effect=fake_create_mesh, + ): + role_to_mesh = pipeline.create_role_to_mesh() + + allocate_mock.assert_called_once_with( + [ + ("actor_model_config", 2), + ("rollout_model_config", 2), + ], + devices=fake_devices, + ) + self.assertEqual(created_mesh_devices["actor_model_config"], ["a0", "a1"]) + self.assertEqual(created_mesh_devices["rollout_model_config"], ["r0", "r1"]) + self.assertEqual( + role_to_mesh[rl_cluster_lib.Role.ACTOR], + ("actor_model_config", ("a0", "a1")), + ) + self.assertIs( + role_to_mesh[rl_cluster_lib.Role.REFERENCE], + role_to_mesh[rl_cluster_lib.Role.ACTOR], + ) + self.assertEqual( + role_to_mesh[rl_cluster_lib.Role.ROLLOUT], + ("rollout_model_config", ("r0", "r1")), + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/utils/mesh_utils_test.py b/tests/utils/mesh_utils_test.py new file mode 100644 index 000000000..44cc4dd53 --- /dev/null +++ b/tests/utils/mesh_utils_test.py @@ -0,0 +1,1265 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +import jax +from tunix.utils import mesh + + +class MeshUtilsTest(absltest.TestCase): + + def test_device_attr_calls_callable_attributes(self): + class FakeDevice: + + def coords(self): + return (1, 2, 3) + + self.assertEqual(mesh.device_attr(FakeDevice(), "coords"), (1, 2, 3)) + self.assertIsNone(mesh.device_attr(FakeDevice(), "missing")) + + def test_device_host_key_prefers_slice_and_process_metadata(self): + class FakeDevice: + + def __init__(self): + self.slice_index = 4 + self.process_index = 7 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (4, 7)) + + def test_device_host_key_falls_back_to_slice_and_task_id(self): + class FakeDevice: + + def __init__(self): + self.slice = 3 + self.task_id = 9 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (3, 9)) + + def test_device_host_key_prefers_logical_task_over_process_index(self): + class FakeDevice: + + def __init__(self): + self.slice_index = 4 + self.process_index = 0 + self.logical_task = 7 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (4, 7)) + + def test_device_host_key_prefers_task_id_over_process_index(self): + class FakeDevice: + + def __init__(self): + self.slice_index = 4 + self.process_index = 0 + self.task_id = 9 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (4, 9)) + + def test_device_host_key_returns_none_without_task_metadata(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.device_host_key(FakeDevice())) + + def test_device_slice_id_prefers_slice_index_then_slice(self): + class SliceIndexDevice: + + def __init__(self): + self.slice_index = 4 + + class SliceDevice: + + def __init__(self): + self.slice = 7 + + self.assertEqual(mesh.device_slice_id(SliceIndexDevice()), 4) + self.assertEqual(mesh.device_slice_id(SliceDevice()), 7) + self.assertIsNone(mesh.device_slice_id(object())) + + def test_group_devices_by_slice_preserves_first_seen_order(self): + class FakeDevice: + + def __init__(self, device_id, slice_index): + self.id = device_id + self.slice_index = slice_index + + grouped = mesh.group_devices_by_slice([ + FakeDevice(0, 2), + FakeDevice(1, 2), + FakeDevice(2, 1), + FakeDevice(3, 1), + ]) + + self.assertEqual([[device.id for device in group] for group in grouped], [[0, 1], [2, 3]]) + + def test_find_candidate_coord_boxes_finds_contiguous_boxes(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0)), + FakeDevice(1, (1, 0, 0)), + FakeDevice(2, (0, 1, 0)), + FakeDevice(3, (1, 1, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertEqual( + mesh.find_candidate_coord_boxes(topology, 4), + [ + ( + (0, 0, 0), + (2, 2, 1), + ((0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + ) + ], + ) + + def test_find_candidate_coord_boxes_skips_missing_coords(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0)), + FakeDevice(1, (1, 0, 0)), + FakeDevice(2, (1, 1, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertEqual(mesh.find_candidate_coord_boxes(topology, 4), []) + + def test_find_candidate_coord_boxes_can_return_multiple_candidates(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0)), + FakeDevice(1, (1, 0, 0)), + FakeDevice(2, (2, 0, 0)), + FakeDevice(3, (3, 0, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertEqual( + mesh.find_candidate_coord_boxes(topology, 2), + [ + ((0, 0, 0), (2, 1, 1), ((0, 0, 0), (1, 0, 0))), + ((1, 0, 0), (2, 1, 1), ((1, 0, 0), (2, 0, 0))), + ((2, 0, 0), (2, 1, 1), ((2, 0, 0), (3, 0, 0))), + ], + ) + + def test_find_candidate_coord_boxes_rejects_split_chip_candidates(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (0, 0, 0), 1), + FakeDevice(2, (1, 0, 0), 0), + FakeDevice(3, (1, 0, 0), 1), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertEqual( + mesh.find_candidate_coord_boxes(topology, 1), + [], + ) + + def test_find_host_aligned_candidate_coord_boxes_respects_exact_host_shape(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + + fake_devices = [] + device_id = 0 + for x in range(4): + for y in range(4): + for z in range(2): + for core_on_chip in (0, 1): + fake_devices.append(FakeDevice(device_id, (x, y, z), core_on_chip)) + device_id += 1 + + topology = mesh.get_coord_topology(fake_devices) + + candidate_boxes = mesh.find_host_aligned_candidate_coord_boxes( + topology, 8, (2, 2, 1, 2) + ) + + self.assertLen(candidate_boxes, 8) + self.assertContainsSubset( + [ + ( + (0, 0, 0, 0), + (2, 2, 1, 2), + ( + (0, 0, 0, 0), + (0, 0, 0, 1), + (0, 1, 0, 0), + (0, 1, 0, 1), + (1, 0, 0, 0), + (1, 0, 0, 1), + (1, 1, 0, 0), + (1, 1, 0, 1), + ), + ), + ( + (0, 0, 1, 0), + (2, 2, 1, 2), + ( + (0, 0, 1, 0), + (0, 0, 1, 1), + (0, 1, 1, 0), + (0, 1, 1, 1), + (1, 0, 1, 0), + (1, 0, 1, 1), + (1, 1, 1, 0), + (1, 1, 1, 1), + ), + ), + ], + candidate_boxes, + ) + + def test_candidate_uses_whole_chips_requires_all_cores(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + + topology = mesh.get_coord_topology([ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (0, 0, 0), 1), + FakeDevice(2, (1, 0, 0), 0), + FakeDevice(3, (1, 0, 0), 1), + ]) + + self.assertFalse( + mesh.candidate_uses_whole_chips( + topology, + [(0, 0, 0, 0), (1, 0, 0, 0)], + ) + ) + self.assertTrue( + mesh.candidate_uses_whole_chips( + topology, + [(0, 0, 0, 0), (0, 0, 0, 1), (1, 0, 0, 0), (1, 0, 0, 1)], + ) + ) + + def test_satisfies_host_bound_shape_rejects_ragged_coords(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + host_devices = [ + FakeDevice(0, (1, 1, 0)), + FakeDevice(1, (1, 0, 1)), + FakeDevice(2, (0, 1, 1)), + FakeDevice(3, (1, 1, 1)), + ] + + self.assertFalse( + mesh._satisfies_host_bound_shape( + host_devices, + (2, 2, 1), + 4, + ) + ) + + def test_get_coord_topology_builds_bounding_box(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (2, 1, 0)), + FakeDevice(1, (3, 1, 0)), + FakeDevice(2, (2, 2, 0)), + FakeDevice(3, (3, 2, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertIsNotNone(topology) + self.assertEqual(topology.num_dims, 3) + self.assertEqual(topology.max_shape, (2, 2, 1)) + self.assertEqual(topology.all_coords, ((2, 1, 0), (3, 1, 0), (2, 2, 0), (3, 2, 0))) + + def test_get_coord_topology_rejects_duplicate_coords(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + + fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0))] + + self.assertIsNone(mesh.get_coord_topology(fake_devices)) + + def test_get_coord_topology_uses_core_on_chip_to_disambiguate_devices(self): + class FakeDevice: + + def __init__(self, coords, core_on_chip): + self.coords = coords + self.core_on_chip = core_on_chip + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((0, 0, 0), 1), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertIsNotNone(topology) + self.assertEqual(topology.all_coords, ((0, 0, 0, 0), (0, 0, 0, 1))) + + def test_get_coord_topology_rejects_empty_device_list(self): + self.assertIsNone(mesh.get_coord_topology([])) + + def test_get_coord_topology_rejects_mismatched_coord_dimensions(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + + fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0, 1))] + + self.assertIsNone(mesh.get_coord_topology(fake_devices)) + + def test_summarize_devices_for_logging_includes_id_coords_and_host(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index, slice_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + self.slice_index = slice_index + + self.assertEqual( + mesh.summarize_devices_for_logging([FakeDevice(11, (1, 2, 0), 5, 6)]), + [{"id": 11, "coords": (1, 2, 0), "host": (6, 5)}], + ) + + def test_group_devices_by_host_groups_equal_sized_hosts(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + grouped = mesh.group_devices_by_host([ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + FakeDevice(3, 1), + ]) + + self.assertEqual([[device.id for device in group] for group in grouped], [[0, 1], [2, 3]]) + + def test_allocate_named_mesh_device_slices_uses_logical_task_host_groups(self): + class FakeDevice: + + def __init__(self, device_id, logical_task, coords): + self.id = device_id + self.process_index = 0 + self.logical_task = logical_task + self.coords = coords + + fake_devices = [] + for device_id in range(16): + host_index = device_id // 2 + fake_devices.append( + FakeDevice( + device_id, + device_id % 2, + (host_index % 2, (host_index // 2) % 2, host_index // 4), + ) + ) + + with mock.patch.object(mesh, "allocate_devices_by_coords", return_value=None): + allocated = mesh.allocate_named_mesh_device_slices( + [("actor", 8)], + devices=fake_devices, + ) + + self.assertEqual( + [device.id for device in allocated["actor"]], + [0, 2, 4, 6, 8, 10, 12, 14], + ) + + def test_group_devices_by_host_returns_none_without_host_metadata(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.group_devices_by_host([FakeDevice()])) + + def test_group_devices_by_host_returns_none_for_inconsistent_host_sizes(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + self.assertIsNone( + mesh.group_devices_by_host([ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + ]) + ) + + def test_host_mesh_shape_infers_consistent_per_host_shape(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((0, 1, 0), 0), + FakeDevice((1, 1, 0), 0), + FakeDevice((2, 0, 0), 1), + FakeDevice((3, 0, 0), 1), + FakeDevice((2, 1, 0), 1), + FakeDevice((3, 1, 0), 1), + ] + + self.assertEqual(mesh.host_mesh_shape(fake_devices), (2, 2, 1)) + + def test_host_mesh_shape_returns_none_for_sparse_host_box(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((1, 1, 0), 0), + ] + + self.assertIsNone(mesh.host_mesh_shape(fake_devices)) + + def test_host_mesh_shape_returns_none_for_inconsistent_host_shapes(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((0, 1, 0), 0), + FakeDevice((1, 1, 0), 0), + FakeDevice((2, 0, 0), 1), + FakeDevice((3, 0, 0), 1), + ] + + self.assertIsNone(mesh.host_mesh_shape(fake_devices)) + + def test_divisors_returns_sorted_unique_factors(self): + self.assertEqual(mesh._divisors(12), [1, 2, 3, 4, 6, 12]) + + def test_enumerate_box_shapes_returns_shapes_with_requested_volume(self): + self.assertEqual( + mesh._enumerate_box_shapes(4, (4, 2, 2)), + [(1, 2, 2), (2, 1, 2), (2, 2, 1), (4, 1, 1)], + ) + + def test_coord_box_score_prefers_host_aligned_boxes(self): + aligned_score = mesh._coord_box_score((0, 0, 0), (2, 2, 1), (2, 2, 1)) + unaligned_score = mesh._coord_box_score((1, 0, 0), (2, 2, 1), (2, 2, 1)) + + self.assertLess(aligned_score, unaligned_score) + + def test_select_best_candidate_coords_prefers_host_aligned_box(self): + candidate_boxes = [ + ((1, 0, 0), (2, 2, 1), ((1, 0, 0), (1, 1, 0), (2, 0, 0), (2, 1, 0))), + ((0, 0, 0), (2, 2, 1), ((0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0))), + ] + + self.assertEqual( + mesh.select_best_candidate_coords(candidate_boxes, (2, 2, 1)), + [(0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0)], + ) + + def test_select_best_candidate_coords_prefers_chip_host_aligned_box_with_core_dimension(self): + candidate_boxes = [ + ( + (0, 0, 0, 0), + (1, 2, 2, 2), + ( + (0, 0, 0, 0), + (0, 0, 0, 1), + (0, 1, 0, 0), + (0, 1, 0, 1), + (0, 0, 1, 0), + (0, 0, 1, 1), + (0, 1, 1, 0), + (0, 1, 1, 1), + ), + ), + ( + (0, 0, 0, 0), + (2, 2, 1, 2), + ( + (0, 0, 0, 0), + (0, 0, 0, 1), + (0, 1, 0, 0), + (0, 1, 0, 1), + (1, 0, 0, 0), + (1, 0, 0, 1), + (1, 1, 0, 0), + (1, 1, 0, 1), + ), + ), + ] + + self.assertEqual( + mesh.select_best_candidate_coords(candidate_boxes, (2, 2, 1, 2)), + [ + (0, 0, 0, 0), + (0, 0, 0, 1), + (0, 1, 0, 0), + (0, 1, 0, 1), + (1, 0, 0, 0), + (1, 0, 0, 1), + (1, 1, 0, 0), + (1, 1, 0, 1), + ], + ) + + def test_select_best_candidate_coords_prefers_more_compact_shape(self): + candidate_boxes = [ + ((0, 0, 0), (1, 4, 1), ((0, 0, 0), (0, 1, 0), (0, 2, 0), (0, 3, 0))), + ((0, 0, 0), (2, 2, 1), ((0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0))), + ] + + self.assertEqual( + mesh.select_best_candidate_coords(candidate_boxes, None), + [(0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0)], + ) + + def test_select_best_candidate_coords_uses_start_as_tiebreaker(self): + candidate_boxes = [ + ((2, 0, 0), (2, 1, 1), ((2, 0, 0), (3, 0, 0))), + ((0, 0, 0), (2, 1, 1), ((0, 0, 0), (1, 0, 0))), + ] + + self.assertEqual( + mesh.select_best_candidate_coords(candidate_boxes, None), + [(0, 0, 0), (1, 0, 0)], + ) + + def test_select_best_candidate_coords_returns_none_without_candidates(self): + self.assertIsNone(mesh.select_best_candidate_coords([], (2, 2, 1))) + + def test_device_mesh_coords_appends_core_on_chip_when_present(self): + class FakeDevice: + + def __init__(self): + self.coords = (1, 2, 0) + self.core_on_chip = 1 + + self.assertEqual( + mesh.device_mesh_coords(FakeDevice()), + (1, 2, 0, 1), + ) + + def test_device_mesh_coords_returns_none_without_coords(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.device_mesh_coords(FakeDevice())) + + def test_known_host_mesh_shape_returns_none_for_unknown_device_family(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0, 0) + self.device_kind = "unknown" + + self.assertIsNone(mesh.known_host_mesh_shape([FakeDevice()])) + + def test_known_host_mesh_shape_returns_none_when_coord_rank_mismatches_bounds(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0) + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice() for _ in range(128)] + + self.assertIsNone(mesh.known_host_mesh_shape(fake_devices)) + + def test_resolve_per_host_mesh_shape_returns_inferred_shape(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((0, 1, 0), 0), + FakeDevice((1, 1, 0), 0), + FakeDevice((2, 0, 0), 1), + FakeDevice((3, 0, 0), 1), + FakeDevice((2, 1, 0), 1), + FakeDevice((3, 1, 0), 1), + ] + + self.assertEqual(mesh.resolve_per_host_mesh_shape(fake_devices), (2, 2, 1)) + + def test_known_host_mesh_shape_uses_static_topology_metadata(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0, 0) + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice() for _ in range(128)] + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (2, 2, 1), + ) + + def test_known_host_mesh_shape_uses_single_host_bounds_for_tpu7x_2(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0))] + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (1, 1, 1), + ) + + def test_known_host_mesh_shape_appends_core_dimension_when_present(self): + class FakeDevice: + + def __init__(self, coords, core_on_chip): + self.coords = coords + self.core_on_chip = core_on_chip + self.device_kind = "TPU v7" + + fake_devices = [] + for x in range(4): + for y in range(4): + for z in range(4): + for core_on_chip in (0, 1): + fake_devices.append(FakeDevice((x, y, z), core_on_chip)) + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (2, 2, 1, 2), + ) + + def test_resolve_per_host_mesh_shape_raises_on_mismatch(self): + class FakeDevice: + + def __init__(self, device_id, coords, logical_task): + self.id = device_id + self.coords = coords + self.logical_task = logical_task + self.device_kind = "TPU v7" + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (2, 0, 0), 0), + FakeDevice(3, (3, 0, 0), 0), + FakeDevice(4, (0, 0, 1), 1), + FakeDevice(5, (1, 0, 1), 1), + FakeDevice(6, (2, 0, 1), 1), + FakeDevice(7, (3, 0, 1), 1), + ] + + with self.assertRaisesRegex(ValueError, "does not match known host bounds"): + mesh.resolve_per_host_mesh_shape(fake_devices) + + def test_allocate_named_mesh_device_slices_prefers_coord_boxes(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0, 0)), + FakeDevice(1, (0, 0, 0, 1)), + FakeDevice(2, (1, 0, 0, 0)), + FakeDevice(3, (1, 0, 0, 1)), + FakeDevice(4, (2, 0, 0, 0)), + FakeDevice(5, (2, 0, 0, 1)), + FakeDevice(6, (3, 0, 0, 0)), + FakeDevice(7, (3, 0, 0, 1)), + FakeDevice(8, (0, 1, 0, 0)), + FakeDevice(9, (0, 1, 0, 1)), + FakeDevice(10, (1, 1, 0, 0)), + FakeDevice(11, (1, 1, 0, 1)), + FakeDevice(12, (2, 1, 0, 0)), + FakeDevice(13, (2, 1, 0, 1)), + FakeDevice(14, (3, 1, 0, 0)), + FakeDevice(15, (3, 1, 0, 1)), + ] + + allocated = mesh.allocate_named_mesh_device_slices( + [("actor", 8)], + devices=fake_devices, + ) + + self.assertEqual( + [device.id for device in allocated["actor"]], + [0, 1, 2, 3, 8, 9, 10, 11], + ) + + def test_allocate_devices_by_coords_uses_core_on_chip_dimension(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(4): + for y in range(4): + for z in range(2): + for core_on_chip in (0, 1): + fake_devices.append(FakeDevice(device_id, (x, y, z), core_on_chip)) + device_id += 1 + + allocated = mesh.allocate_devices_by_coords(fake_devices, 8) + + self.assertEqual( + [device.id for device in allocated], + [0, 1, 4, 5, 16, 17, 20, 21], + ) + + def test_allocate_devices_by_coords_returns_none_without_coord_topology(self): + class FakeDevice: + + def __init__(self, process_index): + self.process_index = process_index + + self.assertIsNone( + mesh.allocate_devices_by_coords([FakeDevice(0), FakeDevice(0)], 2) + ) + + def test_allocate_devices_by_coords_returns_best_contiguous_box(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (0, 1, 0), 0), + FakeDevice(3, (1, 1, 0), 0), + FakeDevice(4, (2, 0, 0), 1), + FakeDevice(5, (3, 0, 0), 1), + FakeDevice(6, (2, 1, 0), 1), + FakeDevice(7, (3, 1, 0), 1), + ] + + allocated = mesh.allocate_devices_by_coords(fake_devices, 4) + + self.assertEqual([device.id for device in allocated], [0, 1, 2, 3]) + + def test_allocate_devices_allocates_single_mesh(self): + fake_devices = [object(), object()] + + with mock.patch.object( + mesh, + "allocate_devices_by_coords", + return_value=fake_devices, + ) as allocate_mock: + allocated = mesh.allocate_devices(2, devices=fake_devices) + + allocate_mock.assert_called_once_with( + fake_devices, + 2, + ) + self.assertIs(allocated, fake_devices) + + def test_allocate_devices_returns_updated_state_for_incremental_use(self): + fake_devices = [object(), object(), object()] + + with mock.patch.object( + mesh, + "allocate_devices_by_coords", + side_effect=[fake_devices[:1], fake_devices[1:]], + ): + assigned_devices, next_state = mesh.allocate_devices( + 1, + devices=fake_devices, + return_state=True, + ) + remaining_devices = list(next_state.remaining_devices) + assigned_devices_2, final_state = mesh.allocate_devices( + 2, + allocation_state=next_state, + return_state=True, + ) + + self.assertEqual(assigned_devices, fake_devices[:1]) + self.assertEqual(remaining_devices, fake_devices[1:]) + self.assertEqual(assigned_devices_2, fake_devices[1:]) + self.assertEqual(list(final_state.remaining_devices), []) + self.assertEqual(final_state.used_device_count, 3) + + def test_allocate_devices_rejects_devices_and_state_together(self): + fake_devices = [object()] + allocation_state = mesh.DeviceAllocationState( + remaining_devices=tuple(fake_devices), + remaining_host_groups=None, + full_devices_per_host=0, + host_bound_shape=None, + host_bound_device_count=None, + total_device_count=1, + ) + + with self.assertRaisesRegex( + ValueError, + "Pass either devices or allocation_state to allocate_devices, not both", + ): + mesh.allocate_devices( + 1, + devices=fake_devices, + allocation_state=allocation_state, + ) + + def test_allocate_devices_prefers_single_slice_before_cross_slice(self): + class FakeDevice: + + def __init__(self, device_id, slice_index, coords): + self.id = device_id + self.slice_index = slice_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (2, 0, 0)), + FakeDevice(2, 0, (4, 0, 0)), + FakeDevice(3, 0, (6, 0, 0)), + FakeDevice(4, 1, (1, 0, 0)), + FakeDevice(5, 1, (3, 0, 0)), + FakeDevice(6, 1, (5, 0, 0)), + FakeDevice(7, 1, (7, 0, 0)), + ] + + allocated = mesh.allocate_devices(4, devices=fake_devices) + + self.assertEqual([device.id for device in allocated], [0, 1, 2, 3]) + + def test_allocate_devices_spills_to_next_slice_in_order(self): + class FakeDevice: + + def __init__(self, device_id, slice_index, coords): + self.id = device_id + self.slice_index = slice_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (2, 0, 0)), + FakeDevice(2, 0, (4, 0, 0)), + FakeDevice(3, 0, (6, 0, 0)), + FakeDevice(4, 1, (1, 0, 0)), + FakeDevice(5, 1, (3, 0, 0)), + FakeDevice(6, 1, (5, 0, 0)), + FakeDevice(7, 1, (7, 0, 0)), + ] + + allocated = mesh.allocate_devices(6, devices=fake_devices) + + self.assertEqual([device.id for device in allocated], [0, 1, 2, 3, 4, 5]) + + def test_allocate_named_mesh_device_slices_calls_allocate_devices_in_loop(self): + fake_devices = [object(), object(), object()] + state_0 = mesh.DeviceAllocationState( + remaining_devices=tuple(fake_devices), + remaining_host_groups=None, + full_devices_per_host=0, + host_bound_shape=None, + host_bound_device_count=None, + total_device_count=3, + used_device_count=0, + ) + state_1 = mesh.DeviceAllocationState( + remaining_devices=tuple(fake_devices[1:]), + remaining_host_groups=None, + full_devices_per_host=0, + host_bound_shape=None, + host_bound_device_count=None, + total_device_count=3, + used_device_count=1, + ) + state_2 = mesh.DeviceAllocationState( + remaining_devices=(), + remaining_host_groups=None, + full_devices_per_host=0, + host_bound_shape=None, + host_bound_device_count=None, + total_device_count=3, + used_device_count=3, + ) + + with mock.patch.object( + mesh, + "allocate_devices", + side_effect=[ + ([fake_devices[0]], state_1), + ([fake_devices[1], fake_devices[2]], state_2), + ], + ) as allocate_mock, mock.patch.object( + mesh, + "_create_device_allocation_state", + return_value=state_0, + ) as state_mock, mock.patch.object( + mesh.logging, + "warning", + ) as warning_mock: + allocated = mesh.allocate_named_mesh_device_slices( + [("mesh1", 1), ("mesh2", 2)], + devices=fake_devices, + ) + + state_mock.assert_called_once_with(fake_devices) + self.assertEqual(allocate_mock.call_count, 2) + self.assertEqual( + allocate_mock.call_args_list, + [ + mock.call( + 1, + mesh_name="mesh1", + allocation_state=state_0, + return_state=True, + ), + mock.call( + 2, + mesh_name="mesh2", + allocation_state=state_1, + return_state=True, + ), + ], + ) + warning_mock.assert_not_called() + self.assertEqual( + allocated, + {"mesh1": [fake_devices[0]], "mesh2": [fake_devices[1], fake_devices[2]]}, + ) + + @mock.patch.object(jax, "device_count") + def test_create_mesh_uses_jax_make_mesh_without_assigned_devices( + self, mock_device_count_fn + ): + mock_device_count_fn.return_value = 4 + expected_mesh = object() + + with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock: + created_mesh = mesh.create_mesh((2, 2), ("x", "y")) + + make_mesh_mock.assert_called_once_with( + (2, 2), + ("x", "y"), + axis_types=(jax.sharding.AxisType.Auto,) * 2, + ) + self.assertIs(created_mesh, expected_mesh) + + def test_create_mesh_uses_assigned_devices(self): + assigned_devices = ["d0", "d1", "d2", "d3"] + + class FakeMesh: + + def __init__(self, devices, axis_names, axis_types=None): + self.devices = devices + self.axis_names = axis_names + self.axis_types = axis_types + + with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): + created_mesh = mesh.create_mesh( + (2, 2), + ("x", "y"), + devices=assigned_devices, + ) + + self.assertEqual(created_mesh.devices.shape, (2, 2)) + self.assertEqual( + created_mesh.devices.flatten().tolist(), + assigned_devices, + ) + self.assertEqual(created_mesh.axis_names, ("x", "y")) + + def test_allocate_named_mesh_device_slices_uses_jax_devices_by_default(self): + class FakeDevice: + + def __init__(self, device_id): + self.id = device_id + + fake_devices = [FakeDevice(0), FakeDevice(1)] + + with mock.patch.object(mesh.jax, "devices", return_value=fake_devices): + allocated = mesh.allocate_named_mesh_device_slices([("trainer", 2)]) + + self.assertEqual([device.id for device in allocated["trainer"]], [0, 1]) + + def test_allocate_named_mesh_device_slices_uses_whole_hosts(self): + class FakeDevice: + + def __init__(self, device_id, process_index, coords): + self.id = device_id + self.process_index = process_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (1, 0, 0)), + FakeDevice(2, 1, (0, 0, 1)), + FakeDevice(3, 1, (1, 0, 1)), + ] + + with mock.patch.object(mesh, "allocate_devices_by_coords", return_value=None): + allocated = mesh.allocate_named_mesh_device_slices( + [("trainer", 2), ("rollout", 2)], + devices=fake_devices, + ) + + self.assertEqual([device.id for device in allocated["trainer"]], [0, 1]) + self.assertEqual([device.id for device in allocated["rollout"]], [2, 3]) + + def test_allocate_named_mesh_device_slices_allows_multiple_single_host_subslices(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (0, 1, 0), 0), + FakeDevice(3, (1, 1, 0), 0), + FakeDevice(4, (0, 2, 0), 0), + FakeDevice(5, (1, 2, 0), 0), + FakeDevice(6, (0, 3, 0), 0), + FakeDevice(7, (1, 3, 0), 0), + ] + + with mock.patch.object(mesh, "allocate_devices_by_coords", return_value=None): + allocated = mesh.allocate_named_mesh_device_slices( + [("actor", 2), ("reference", 2), ("rollout", 2)], + devices=fake_devices, + ) + + self.assertEqual([device.id for device in allocated["actor"]], [0, 1]) + self.assertEqual([device.id for device in allocated["reference"]], [2, 3]) + self.assertEqual([device.id for device in allocated["rollout"]], [4, 5]) + + def test_allocate_named_mesh_device_slices_reuses_partial_host_leftovers(self): + class FakeDevice: + + def __init__(self, device_id, process_index, coords): + self.id = device_id + self.process_index = process_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (1, 0, 0)), + FakeDevice(2, 0, (0, 1, 0)), + FakeDevice(3, 0, (1, 1, 0)), + FakeDevice(4, 1, (0, 0, 1)), + FakeDevice(5, 1, (1, 0, 1)), + FakeDevice(6, 1, (0, 1, 1)), + FakeDevice(7, 1, (1, 1, 1)), + ] + + with mock.patch.object(mesh, "allocate_devices_by_coords", return_value=None): + allocated = mesh.allocate_named_mesh_device_slices( + [("mesh1", 3), ("mesh2", 2), ("mesh3", 2), ("mesh4", 1)], + devices=fake_devices, + ) + + self.assertEqual([device.id for device in allocated["mesh1"]], [0, 1, 2]) + self.assertEqual([device.id for device in allocated["mesh2"]], [4, 5]) + self.assertEqual([device.id for device in allocated["mesh3"]], [6, 7]) + self.assertEqual([device.id for device in allocated["mesh4"]], [3]) + + def test_allocate_named_mesh_device_slices_reuses_valid_leftover_host_group(self): + class FakeDevice: + + def __init__(self, device_id, process_index, coords): + self.id = device_id + self.process_index = process_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (1, 0, 0)), + FakeDevice(2, 0, (0, 1, 0)), + FakeDevice(3, 0, (1, 1, 0)), + FakeDevice(4, 0, (0, 0, 1)), + FakeDevice(5, 0, (1, 0, 1)), + FakeDevice(6, 0, (0, 1, 1)), + FakeDevice(7, 0, (1, 1, 1)), + FakeDevice(8, 1, (0, 0, 2)), + FakeDevice(9, 1, (1, 0, 2)), + FakeDevice(10, 1, (0, 1, 2)), + FakeDevice(11, 1, (1, 1, 2)), + FakeDevice(12, 1, (0, 0, 3)), + FakeDevice(13, 1, (1, 0, 3)), + FakeDevice(14, 1, (0, 1, 3)), + FakeDevice(15, 1, (1, 1, 3)), + ] + + with mock.patch.object(mesh, "allocate_devices_by_coords", return_value=None): + allocated = mesh.allocate_named_mesh_device_slices( + [("mesh1", 6), ("mesh2", 2)], + devices=fake_devices, + ) + + self.assertEqual([device.id for device in allocated["mesh1"]], [0, 1, 2, 3, 4, 5]) + self.assertEqual([device.id for device in allocated["mesh2"]], [6, 7]) + + def test_allocate_named_mesh_device_slices_allocates_full_host_then_remainder(self): + class FakeDevice: + + def __init__(self, device_id, process_index, coords): + self.id = device_id + self.process_index = process_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (1, 0, 0)), + FakeDevice(2, 1, (0, 0, 1)), + FakeDevice(3, 1, (1, 0, 1)), + ] + + with mock.patch.object(mesh, "allocate_devices_by_coords", return_value=None): + allocated = mesh.allocate_named_mesh_device_slices( + [("trainer", 3)], + devices=fake_devices, + ) + + self.assertEqual([device.id for device in allocated["trainer"]], [0, 1, 2]) + + def test_allocate_named_mesh_device_slices_raises_without_host_bound_metadata(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + FakeDevice(3, 1), + ] + + with self.assertRaisesRegex( + ValueError, + "Host-group allocation requires an inferable host-bound shape and device count", + ): + with mock.patch.object(mesh, "allocate_devices_by_coords", return_value=None): + mesh.allocate_named_mesh_device_slices( + [("trainer", 2)], + devices=fake_devices, + ) + + def test_allocate_named_mesh_device_slices_raises_when_not_enough_hosts(self): + class FakeDevice: + + def __init__(self, device_id, process_index, coords): + self.id = device_id + self.process_index = process_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (1, 0, 0)), + FakeDevice(2, 1, (0, 0, 1)), + FakeDevice(3, 1, (1, 0, 1)), + ] + + with self.assertRaisesRegex(ValueError, "but only 2 are available"): + with mock.patch.object(mesh, "allocate_devices_by_coords", return_value=None): + mesh.allocate_named_mesh_device_slices( + [("trainer", 6)], + devices=fake_devices, + ) + + def test_allocate_named_mesh_device_slices_raises_when_not_enough_devices(self): + class FakeDevice: + + def __init__(self, device_id): + self.id = device_id + + fake_devices = [FakeDevice(0), FakeDevice(1)] + + with self.assertRaisesRegex(ValueError, "but only 2 remain available"): + mesh.allocate_named_mesh_device_slices( + [("trainer", 3)], + devices=fake_devices, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/utils/topology_test.py b/tests/utils/topology_test.py new file mode 100644 index 000000000..db8ffe25b --- /dev/null +++ b/tests/utils/topology_test.py @@ -0,0 +1,85 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from tunix.utils import topology + + +class TopologyTest(absltest.TestCase): + + def test_normalize_device_kind_recognizes_supported_families(self): + self.assertEqual(topology._normalize_device_kind("TPU v7"), "tpu7x") + self.assertEqual(topology._normalize_device_kind("TPU v6e"), "v6e") + self.assertEqual(topology._normalize_device_kind("TPU v5e"), "v5e") + self.assertEqual(topology._normalize_device_kind("TPU v5p"), "v5p") + self.assertEqual(topology._normalize_device_kind("TPU v4"), "v4") + self.assertIsNone(topology._normalize_device_kind("gpu")) + + def test_infer_chips_per_host_bounds_returns_none_for_empty_devices(self): + self.assertIsNone(topology.infer_chips_per_host_bounds([])) + + def test_infer_chips_per_host_bounds_returns_none_for_missing_device_kind(self): + class FakeDevice: + pass + + self.assertIsNone(topology.infer_chips_per_host_bounds([FakeDevice()])) + + def test_infer_chips_per_host_bounds_uses_single_host_shapes(self): + class FakeDevice: + + def __init__(self, device_kind): + self.device_kind = device_kind + + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v5e")]), + (1, 1, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v6e")]), + (1, 1, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v7"), FakeDevice("TPU v7")]), + (1, 1, 1), + ) + + def test_infer_chips_per_host_bounds_uses_multi_host_shape_otherwise(self): + class FakeDevice: + + def __init__(self, device_kind): + self.device_kind = device_kind + + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v7") for _ in range(4)]), + (2, 2, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v4") for _ in range(8)]), + (2, 2, 1), + ) + + def test_infer_chips_per_host_bounds_handles_callable_device_kind(self): + class FakeDevice: + + def device_kind(self): + return "TPU v7" + + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice() for _ in range(128)]), + (2, 2, 1), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tunix/cli/config.py b/tunix/cli/config.py index 9d2ae3a62..582820a8a 100644 --- a/tunix/cli/config.py +++ b/tunix/cli/config.py @@ -691,51 +691,6 @@ def _parse_mesh_config(self, model_key: str) -> tuple[tuple[int, ...], tuple[str ) return tuple(axis_shapes), tuple(axis_names) - def create_mesh(self, model_key: str, devices: Sequence[Any] | None = None): - """Validate and extract mesh configuration from a dictionary. - - Expects raw_keys to contain a 'mesh' key, which is a dictionary with 'shape' - and 'axis_names' keys. - - Args: - model_key: A model key that contain raw mesh configuration. For example, - in rl, there are actor_model, critic_model and reference_model, each of - them could have different mesh configuration. - devices: Optional explicit device subset to use for the mesh. When - provided, the mesh shape must exactly match the number of assigned - devices. - - Returns: - A tuple containing (axis_shapes, axis_names), both as tuples. - - Raises: - ValueError: If the mesh configuration is missing, malformed, or invalid. - """ - - axis_shapes, axis_names = self._parse_mesh_config(model_key) - num_devices = len(devices) if devices is not None else jax.device_count() - if np.prod(axis_shapes) > num_devices: - raise ValueError( - f"Mesh shape {axis_shapes} requires {np.prod(axis_shapes)} devices, " - f"but found {num_devices}." - ) - if devices is not None: - if np.prod(axis_shapes) != num_devices: - raise ValueError( - f"Mesh shape {axis_shapes} requires {np.prod(axis_shapes)} devices, " - f"but was assigned {num_devices}." - ) - return jax.sharding.Mesh( - np.array(list(devices)).reshape(axis_shapes), - axis_names, - axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names), - ) - return jax.make_mesh( - axis_shapes, - axis_names, - axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names), - ) - def obtain_training_config_dict(self, key): """Obtain training config dictionary from specified key in self.config. diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index d248ccd49..13a614288 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -52,6 +52,7 @@ from tunix.perf.experimental import export as perf_export_v2 from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout +from tunix.utils import mesh as mesh_lib _PATHWAYS_BNS = flags.DEFINE_string( @@ -189,7 +190,13 @@ def resolve_owner( role_to_owner[role] = resolve_owner(role, set()) return role_to_owner - def _create_role_to_mesh(self): + def create_role_to_mesh(self): + """Build role→mesh mapping. + + Any role with an explicit ``*.mesh`` config gets a dedicated device slice. + Roles without a mesh share the actor mesh by default, or can point at + another role via ``same_mesh_as``. + """ devices = list(jax.devices()) role_to_owner = self._resolve_mesh_owners() owner_order = [] @@ -200,50 +207,26 @@ def _create_role_to_mesh(self): if owner not in owner_order: owner_order.append(owner) - owner_to_mesh = {} - owner_to_device_slice = {} - device_offset = 0 + mesh_requirements = [] for owner in owner_order: model_key = self._ROLE_TO_MODEL_KEY[owner] axis_shapes, _ = self._parse_mesh_config(model_key) - required_devices = int(np.prod(axis_shapes)) - next_offset = device_offset + required_devices - if next_offset > len(devices): - raise ValueError( - f"Mesh allocation requires {next_offset} devices after allocating" - f" {model_key}, but only {len(devices)} are available." - ) - assigned_devices = devices[device_offset:next_offset] - owner_to_device_slice[owner] = assigned_devices - owner_to_mesh[owner] = self.create_mesh( - model_key, devices=assigned_devices - ) - device_offset = next_offset + mesh_requirements.append((model_key, int(np.prod(axis_shapes)))) - if device_offset < len(devices): - logging.warning( - "Mesh allocation used %d of %d devices; %d devices remain unused.", - device_offset, - len(devices), - len(devices) - device_offset, - ) - logging.info( - "Mesh device allocation: %s", - { - self._ROLE_TO_MODEL_KEY[owner]: len(owner_to_device_slice[owner]) - for owner in owner_order - }, + allocated_devices = mesh_lib.allocate_named_mesh_device_slices( + mesh_requirements, + devices=devices, ) - return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} - - def create_role_to_mesh(self): - """Build role→mesh mapping. - Any role with an explicit ``*.mesh`` config gets a dedicated device slice. - Roles without a mesh share the actor mesh by default, or can point at - another role via ``same_mesh_as``. - """ - return self._create_role_to_mesh() + owner_to_mesh = {} + for owner in owner_order: + model_key = self._ROLE_TO_MODEL_KEY[owner] + axis_shapes, axis_names = self._parse_mesh_config(model_key) + assigned_devices = allocated_devices[model_key] + owner_to_mesh[owner] = mesh_lib.create_mesh( + axis_shapes, axis_names, devices=assigned_devices + ) + return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} # ------------------------------------------------------------------ # Rollout config diff --git a/tunix/cli/peft_main.py b/tunix/cli/peft_main.py index e3724a7a4..334e73233 100644 --- a/tunix/cli/peft_main.py +++ b/tunix/cli/peft_main.py @@ -25,6 +25,7 @@ from tunix.examples.data import translation_dataset as data_lib from tunix.sft import peft_trainer from tunix.sft import utils +from tunix.utils import mesh as mesh_lib _PATHWAYS_BNS = flags.DEFINE_string( "pathways_bns", None, "BNS address of the Pathways server." @@ -36,7 +37,8 @@ class PeftPipeline(config.HyperParameters): def run_peft_trainer(self): """Run the PEFT trainer.""" - mesh: jax.sharding.Mesh = self.create_mesh('model_config') + axis_shapes, axis_names = self._parse_mesh_config('model_config') + mesh: jax.sharding.Mesh = mesh_lib.create_mesh(axis_shapes, axis_names) model: nnx.Module | None = None tokenizer: Any | None = None my_gen_model_input_fn: ( diff --git a/tunix/utils/mesh.py b/tunix/utils/mesh.py new file mode 100644 index 000000000..54dff7b49 --- /dev/null +++ b/tunix/utils/mesh.py @@ -0,0 +1,1338 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared mesh device allocation helpers. + +Typical usage: + + allocations = allocate_named_mesh_device_slices([ + ("actor", 8), + ("rollout", 4), + ]) + +The keys are arbitrary mesh names chosen by the caller. The integer is the +number of devices that mesh should receive. +""" + +import collections +import dataclasses +from typing import Any, Sequence + +from absl import logging +import jax +import numpy as np +from tunix.utils import topology + +MeshRequirement = tuple[str, int] + + +def create_mesh( + axis_shapes: tuple[int, ...], + axis_names: tuple[str, ...], + devices: Sequence[Any] | None = None, +): + """Builds a JAX mesh from parsed axis metadata.""" + if len(axis_shapes) != len(axis_names): + raise ValueError( + f"mesh.shape {axis_shapes} and mesh.axis_names {axis_names} " + "must have the same length." + ) + + num_devices = len(devices) if devices is not None else jax.device_count() + required_devices = int(np.prod(axis_shapes)) + if required_devices > num_devices: + raise ValueError( + f"Mesh shape {axis_shapes} requires {required_devices} devices, " + f"but found {num_devices}." + ) + if devices is not None: + if required_devices != num_devices: + raise ValueError( + f"Mesh shape {axis_shapes} requires {required_devices} devices, " + f"but was assigned {num_devices}." + ) + return jax.sharding.Mesh( + np.array(list(devices)).reshape(axis_shapes), + axis_names, + axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names), + ) + return jax.make_mesh( + axis_shapes, + axis_names, + axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names), + ) + + +@dataclasses.dataclass(frozen=True) +class CoordTopology: + """Normalized coord metadata for a device pool. + + Attributes: + coord_to_device: Mapping from physical coords to device objects. + all_coords: Normalized coord tuples for all devices. + num_dims: Number of coord dimensions. + max_shape: Bounding-box shape of the device pool. + """ + + coord_to_device: dict[tuple[int, ...], Any] + all_coords: tuple[tuple[int, ...], ...] + num_dims: int + max_shape: tuple[int, ...] + chip_coord_to_coords: dict[tuple[int, ...], tuple[tuple[int, ...], ...]] + + +@dataclasses.dataclass(frozen=True) +class DeviceAllocationState: + """Tracks the remaining device pool across sequential mesh allocations. + + This state object exists so `allocate_devices()` can be the lowest-level + public API while still supporting multi-mesh allocation. Callers that only + need one mesh can pass `devices=` directly to `allocate_devices()`. Callers + that need multiple meshes can create state once and repeatedly allocate from + it, which is exactly what `allocate_named_mesh_device_slices()` does. + + Attributes: + remaining_devices: Flat view of devices that have not yet been assigned. + remaining_host_groups: Optional per-host buckets used by the host-aware + fallback path. This becomes `None` once allocation falls back to a purely + flat remaining-device pool. + full_devices_per_host: Original per-host capacity derived from host groups. + host_bound_shape: Per-host physical topology shape, such as `(2, 2, 1)`. + host_bound_device_count: Device count implied by `host_bound_shape`. + total_device_count: Size of the original device pool. + used_device_count: Number of devices already assigned. + """ + + remaining_devices: tuple[Any, ...] + remaining_host_groups: tuple[tuple[Any, ...], ...] | None + full_devices_per_host: int + host_bound_shape: tuple[int, ...] | None + host_bound_device_count: int | None + total_device_count: int + used_device_count: int = 0 + + +def device_attr(device: Any, attr_name: str) -> Any: + """Returns a raw device attribute, calling it first if JAX exposes it lazily. + + Args: + device: A JAX device or test double. + attr_name: Attribute name such as "coords" or "process_index". + + Returns: + The attribute value, or None if the attribute does not exist. + """ + value = getattr(device, attr_name, None) + return value() if callable(value) else value + + +def device_host_key(device: Any) -> tuple[Any, ...] | None: + """Returns a stable host grouping key for topology-aware allocation. + + Args: + device: A JAX device or test double. + + Returns: + A tuple of (slice_id, task_id) when that metadata is available, otherwise + None. + """ + task_id = None + for attr_name in ("logical_task", "task_id", "process_index"): + task_id = device_attr(device, attr_name) + if task_id is not None: + break + if task_id is None: + return None + + slice_id = None + for attr_name in ("slice_index", "slice"): + slice_id = device_attr(device, attr_name) + if slice_id is not None: + break + return (slice_id, task_id) + + +def device_slice_id(device: Any) -> Any: + """Returns the slice identifier when the runtime exposes one. + + This is intentionally narrower than `device_host_key()`: it captures only the + slice boundary, not the host/task within that slice. Slice-aware allocation + uses this to prefer satisfying a mesh from one slice before spilling into the + next slice. + """ + for attr_name in ("slice_index", "slice"): + slice_id = device_attr(device, attr_name) + if slice_id is not None: + return slice_id + return None + + +def device_mesh_coords(device: Any) -> tuple[int, ...] | None: + """Returns physical mesh coordinates for topology-aware allocation. + + Args: + device: A JAX device or test double. + + Returns: + A tuple like (x, y, z) or (x, y, z, core) when the runtime exposes device + coordinates, otherwise None. + """ + coords = device_attr(device, "coords") + if coords is None: + return None + + coords = tuple(coords) + if not coords: + return None + + normalized_coords = tuple(int(coord) for coord in coords) + core_on_chip = device_attr(device, "core_on_chip") + if core_on_chip is None: + return normalized_coords + return normalized_coords + (int(core_on_chip),) + + +def infer_core_on_chip_count(devices: Sequence[Any]) -> int | None: + """Returns the per-chip core count when the runtime exposes it consistently.""" + chip_to_cores = collections.defaultdict(set) + saw_any_core = False + + for device in devices: + coords = device_attr(device, "coords") + core_on_chip = device_attr(device, "core_on_chip") + if coords is None: + return None + if core_on_chip is None: + continue + saw_any_core = True + chip_to_cores[tuple(int(coord) for coord in coords)].add(int(core_on_chip)) + + if not saw_any_core: + return None + + core_counts = {len(core_ids) for core_ids in chip_to_cores.values()} + if len(core_counts) != 1: + return None + return next(iter(core_counts)) + + +def summarize_devices_for_logging(devices: Sequence[Any]) -> list[dict[str, Any]]: + """Builds compact log-friendly summaries for a device list. + + Args: + devices: Devices to summarize. + + Returns: + A list of dictionaries containing device id, coords, and inferred host key. + """ + summaries = [] + for device in devices: + summaries.append({ + "id": device_attr(device, "id"), + "coords": device_mesh_coords(device), + "host": device_host_key(device), + }) + return summaries + + +def summarize_devices_for_debug_logging( + devices: Sequence[Any], + limit: int = 16, +) -> list[dict[str, Any]]: + """Builds richer device summaries for topology debugging. + + Args: + devices: Devices to summarize. + limit: Maximum number of devices to include. + + Returns: + A list of dictionaries with raw device topology metadata. + """ + summaries = [] + for device in devices[:limit]: + summaries.append({ + "id": device_attr(device, "id"), + "coords": device_attr(device, "coords"), + "core_on_chip": device_attr(device, "core_on_chip"), + "process_index": device_attr(device, "process_index"), + "logical_task": device_attr(device, "logical_task"), + "task_id": device_attr(device, "task_id"), + "slice_index": device_attr(device, "slice_index"), + "slice": device_attr(device, "slice"), + "host": device_host_key(device), + }) + return summaries + + +def summarize_host_groups_for_logging(devices: Sequence[Any]) -> dict[tuple[Any, ...], int]: + """Summarizes device counts per derived host key for debug logging.""" + host_counts = collections.Counter() + for device in devices: + host_key = device_host_key(device) + host_counts[host_key] += 1 + return dict(sorted(host_counts.items(), key=lambda item: str(item[0]))) + + +def group_devices_by_slice(devices: Sequence[Any]) -> list[list[Any]] | None: + """Groups devices by slice while preserving first-seen slice order. + + Returns `None` when slice metadata is unavailable for any device. The order of + groups matches the first appearance of each slice in `devices`, which lets the + allocator prefer earlier slices before spilling into later ones. + """ + slice_to_devices = {} + for device in devices: + slice_id = device_slice_id(device) + if slice_id is None: + return None + slice_to_devices.setdefault(slice_id, []).append(device) + return list(slice_to_devices.values()) + + +def group_devices_by_host(devices: Sequence[Any]) -> list[list[Any]] | None: + """Groups devices by host/task when that metadata is available. + + Args: + devices: Candidate devices to partition. + + Returns: + A list of equal-sized per-host device lists, or None if host metadata is + missing or inconsistent. + """ + host_to_devices = {} + for device in devices: + host_key = device_host_key(device) + if host_key is None: + return None + host_to_devices.setdefault(host_key, []).append(device) + + host_sizes = {len(host_devices) for host_devices in host_to_devices.values()} + if len(host_sizes) != 1: + logging.warning( + "Falling back to flat device allocation because host sizes differ: %s", + sorted(host_sizes), + ) + return None + return list(host_to_devices.values()) + + +def host_mesh_shape(devices: Sequence[Any]) -> tuple[int, ...] | None: + """Returns the per-host physical box shape when coords are available. + + Args: + devices: Devices spanning one or more hosts. + + Returns: + The shape of one host in physical coords, such as (2, 2, 1), or None when + it cannot be inferred reliably. + """ + host_to_coords = collections.defaultdict(list) + for device in devices: + host_key = device_host_key(device) + coords = device_mesh_coords(device) + if host_key is None or coords is None: + return None + host_to_coords[host_key].append(coords) + + host_shapes = set() + for coords_list in host_to_coords.values(): + ndim = len(coords_list[0]) + mins = tuple(min(coords[i] for coords in coords_list) for i in range(ndim)) + maxs = tuple(max(coords[i] for coords in coords_list) for i in range(ndim)) + shape = tuple(max_coord - min_coord + 1 for min_coord, max_coord in zip(mins, maxs)) + if int(np.prod(shape)) != len(coords_list): + return None + host_shapes.add(shape) + + if len(host_shapes) != 1: + return None + return next(iter(host_shapes)) + + +def get_coord_topology(devices: Sequence[Any]) -> CoordTopology | None: + """Builds normalized coord metadata for a device pool. + + Args: + devices: Candidate devices to inspect. + + Returns: + A CoordTopology describing the device coords and overall bounding box, or + None when the devices do not expose a consistent coord layout. + """ + if not devices: + return None + + coord_to_device = {} + all_coords = [] + for device in devices: + coords = device_mesh_coords(device) + if coords is None: + logging.info( + "Coord topology unavailable because device lacks coords: %s", + summarize_devices_for_debug_logging([device]), + ) + return None + if all_coords and len(coords) != len(all_coords[0]): + logging.info( + "Coord topology unavailable because coord rank differs: existing_rank=%d device=%s", + len(all_coords[0]), + summarize_devices_for_debug_logging([device]), + ) + return None + if coords in coord_to_device: + logging.info( + "Coord topology unavailable because multiple devices share coords %s: %s", + coords, + summarize_devices_for_debug_logging([coord_to_device[coords], device]), + ) + return None + coord_to_device[coords] = device + all_coords.append(coords) + + num_dims = len(all_coords[0]) + chip_coord_to_coords = collections.defaultdict(list) + for coords in all_coords: + chip_coord_to_coords[coords[:-1]].append(coords) + max_shape = tuple( + max(coords[dim] for coords in all_coords) + - min(coords[dim] for coords in all_coords) + + 1 + for dim in range(num_dims) + ) + return CoordTopology( + coord_to_device=coord_to_device, + all_coords=tuple(all_coords), + num_dims=num_dims, + max_shape=max_shape, + chip_coord_to_coords={ + chip_coord: tuple(sorted(group_coords)) + for chip_coord, group_coords in chip_coord_to_coords.items() + }, + ) + + +def candidate_uses_whole_chips( + coord_topology: CoordTopology, + candidate_coords: Sequence[tuple[int, ...]], +) -> bool: + """Returns whether a candidate includes all logical devices for each chip. + + When multiple logical devices share the same physical chip coordinates, a + valid Pathways subslice must include either all of them or none of them. + This rejects candidates that split `core_on_chip` siblings across meshes. + """ + if coord_topology.num_dims <= 1: + return True + + selected_coords = set(candidate_coords) + selected_chip_coords = {coords[:-1] for coords in selected_coords} + for chip_coord in selected_chip_coords: + chip_group = coord_topology.chip_coord_to_coords.get(chip_coord, ()) + if any(coords not in selected_coords for coords in chip_group): + return False + return True + + +def known_host_mesh_shape(devices: Sequence[Any]) -> tuple[int, ...] | None: + """Returns known host bounds from static topology metadata when available. + + Args: + devices: Devices from a single TPU slice. + + Returns: + A known per-host physical bound such as (1, 1, 1) or (2, 2, 1), or None if + the accelerator family is unknown. + """ + bounds = topology.infer_chips_per_host_bounds(devices) + if bounds is None: + return None + + coords = device_mesh_coords(devices[0]) if devices else None + if coords is None: + return None + + if len(coords) == len(bounds): + return bounds + + if len(coords) == len(bounds) + 1: + core_count = infer_core_on_chip_count(devices) + if core_count is None: + return None + return bounds + (core_count,) + + return None + + +def resolve_per_host_mesh_shape(devices: Sequence[Any]) -> tuple[int, ...] | None: + """Resolves per-host shape and validates inferred vs known topology. + + Args: + devices: Devices spanning one or more hosts. + + Returns: + The inferred per-host shape when available, otherwise the known static host + bounds. + + Raises: + ValueError: If runtime-inferred host shape disagrees with known static host + bounds for the device family. + """ + inferred_shape = host_mesh_shape(devices) + static_shape = known_host_mesh_shape(devices) + if ( + inferred_shape is not None + and static_shape is not None + and inferred_shape != static_shape + ): + raise ValueError( + "Inferred per-host device shape " + f"{inferred_shape} does not match known host bounds {static_shape}." + ) + return inferred_shape or static_shape + + +def _divisors(value: int) -> list[int]: + divisors = set() + for candidate in range(1, int(np.sqrt(value)) + 1): + if value % candidate == 0: + divisors.add(candidate) + divisors.add(value // candidate) + return sorted(divisors) + + +def _enumerate_box_shapes( + required_devices: int, + max_shape: tuple[int, ...], +) -> list[tuple[int, ...]]: + """Enumerates box shapes whose volume matches the requested device count.""" + shapes = [] + num_dims = len(max_shape) + + def build(dim_index: int, remaining: int, prefix: tuple[int, ...]): + if dim_index == num_dims - 1: + if remaining <= max_shape[dim_index]: + shapes.append(prefix + (remaining,)) + return + + for size in _divisors(remaining): + if size > max_shape[dim_index]: + continue + build(dim_index + 1, remaining // size, prefix + (size,)) + + build(0, required_devices, ()) + return shapes + + +def _coord_box_score( + start: tuple[int, ...], + shape: tuple[int, ...], + host_shape: tuple[int, ...] | None, +) -> tuple[Any, ...]: + """Builds a lexicographic sort key for candidate coord boxes. + + The returned tuple is ordered so Python tuple comparison implements the + desired ranking policy directly: + + 1. Prefer host-aligned boxes when host_shape is known. + 2. Prefer boxes with a smaller maximum dimension. + 3. Prefer more compact overall shapes. + 4. Prefer earlier start coordinates as a stable tiebreaker. + + Args: + start: Candidate box origin. + shape: Candidate box shape. + host_shape: Per-host physical shape such as (2, 2, 1). + + Returns: + A tuple sort key suitable for lexicographic comparison. + """ + chip_host_alignment = 1 + full_host_alignment = 1 + if host_shape is not None: + chip_dims = min(3, len(shape), len(host_shape)) + chip_aligned = all( + start[dim] % host_shape[dim] == 0 + and shape[dim] % host_shape[dim] == 0 + for dim in range(chip_dims) + if host_shape[dim] > 1 + ) + fully_aligned = all( + start[dim] % host_shape[dim] == 0 + and shape[dim] % host_shape[dim] == 0 + for dim in range(len(shape)) + if host_shape[dim] > 1 + ) + chip_host_alignment = 0 if chip_aligned else 1 + full_host_alignment = 0 if fully_aligned else 1 + return ( + chip_host_alignment, + full_host_alignment, + max(shape), + tuple(sorted(shape, reverse=True)), + tuple(-dim for dim in shape), + start, + ) + + +def select_best_candidate_coords( + candidate_boxes: Sequence[ + tuple[tuple[int, ...], tuple[int, ...], Sequence[tuple[int, ...]]] + ], + host_shape: tuple[int, ...] | None, +) -> list[tuple[int, ...]] | None: + """Selects the best candidate coord box using the mesh heuristic. + + Args: + candidate_boxes: Sequence of (start, shape, candidate_coords) tuples. + `start` is the box origin, `shape` is the physical box shape, and + `candidate_coords` are the device coords inside that box. + host_shape: Per-host physical shape such as (2, 2, 1), used to prefer + host-aligned boxes when available. + + Returns: + The candidate coord list for the best-ranked box, or None when there are no + candidates. + + Notes: + Candidate boxes are ranked by `_coord_box_score()`, which uses a + lexicographic sort key instead of a single numeric score. This makes the + priority order explicit and avoids arbitrary weighting between ranking + factors. + """ + best_candidate_coords = None + best_score = None + for start, shape, candidate_coords in candidate_boxes: + score = _coord_box_score(start, shape, host_shape) + if best_score is None or score < best_score: + best_score = score + best_candidate_coords = list(candidate_coords) + return best_candidate_coords + + +def find_candidate_coord_boxes( + coord_topology: CoordTopology, + required_devices: int, +) -> list[tuple[tuple[int, ...], tuple[int, ...], tuple[tuple[int, ...], ...]]]: + """Finds contiguous candidate coord boxes for a requested device count. + + Args: + coord_topology: Normalized coord metadata for the candidate device pool. + required_devices: Number of devices needed for one mesh. + + Returns: + A list of (start, shape, candidate_coords) tuples representing contiguous + coord boxes whose volume matches required_devices. + + Notes: + This function only enumerates valid contiguous boxes that exist in the + current device pool. It does not choose among them; ranking is handled by + `select_best_candidate_coords()`. + """ + candidate_boxes = [] + for shape in _enumerate_box_shapes(required_devices, coord_topology.max_shape): + for start in coord_topology.coord_to_device: + candidate_coords = [] + for offset in np.ndindex(shape): + candidate_coord = tuple( + start[dim] + offset[dim] for dim in range(coord_topology.num_dims) + ) + if candidate_coord not in coord_topology.coord_to_device: + break + candidate_coords.append(candidate_coord) + else: + if candidate_uses_whole_chips(coord_topology, candidate_coords): + candidate_boxes.append((start, shape, tuple(candidate_coords))) + return candidate_boxes + + +def find_host_aligned_candidate_coord_boxes( + coord_topology: CoordTopology, + required_devices: int, + host_shape: tuple[int, ...], +) -> list[tuple[tuple[int, ...], tuple[int, ...], tuple[tuple[int, ...], ...]]]: + """Finds contiguous candidate boxes that exactly respect host bounds. + + Args: + coord_topology: Normalized coord metadata for the candidate device pool. + required_devices: Number of devices needed for one mesh. + host_shape: Known per-host physical shape such as (2, 2, 1) or + (2, 2, 1, 2). + + Returns: + A list of valid coord boxes whose shape is an exact multiple of host_shape. + """ + if len(host_shape) != coord_topology.num_dims: + return [] + + host_volume = int(np.prod(host_shape)) + if host_volume <= 0 or required_devices % host_volume != 0: + return [] + + host_grid_shape = tuple( + coord_topology.max_shape[dim] // host_shape[dim] + for dim in range(coord_topology.num_dims) + ) + required_host_boxes = required_devices // host_volume + + candidate_boxes = [] + for host_box_shape in _enumerate_box_shapes(required_host_boxes, host_grid_shape): + physical_shape = tuple( + host_box_shape[dim] * host_shape[dim] + for dim in range(coord_topology.num_dims) + ) + for start in coord_topology.coord_to_device: + if any( + start[dim] % host_shape[dim] != 0 + for dim in range(coord_topology.num_dims) + if host_shape[dim] > 1 + ): + continue + + candidate_coords = [] + for offset in np.ndindex(physical_shape): + candidate_coord = tuple( + start[dim] + offset[dim] for dim in range(coord_topology.num_dims) + ) + if candidate_coord not in coord_topology.coord_to_device: + break + candidate_coords.append(candidate_coord) + else: + if candidate_uses_whole_chips(coord_topology, candidate_coords): + candidate_boxes.append((start, physical_shape, tuple(candidate_coords))) + return candidate_boxes + + +def allocate_devices_by_coords( + devices: Sequence[Any], + required_devices: int, +) -> list[Any] | None: + """Allocates a contiguous physical box of devices when coords exist. + + Args: + devices: Candidate devices to allocate from. + required_devices: Number of devices needed for one mesh. + + Returns: + A list of devices forming the best contiguous physical box, or None if the + devices do not expose usable coordinates. + + Notes: + This helper runs in three stages: + + 1. Build normalized coord metadata with `get_coord_topology()`. + 2. Enumerate valid contiguous candidate boxes with + `find_candidate_coord_boxes()`. + 3. Rank those candidates with `select_best_candidate_coords()` and map the + winning coords back to device objects. + """ + coord_topology = get_coord_topology(devices) + if coord_topology is None: + return None + per_host_shape = resolve_per_host_mesh_shape(devices) + + candidate_boxes = [] + if per_host_shape is not None: + candidate_boxes = find_host_aligned_candidate_coord_boxes( + coord_topology, + required_devices, + per_host_shape, + ) + if not candidate_boxes: + candidate_boxes = find_candidate_coord_boxes(coord_topology, required_devices) + + best_candidate_coords = select_best_candidate_coords( + candidate_boxes, + per_host_shape, + ) + if best_candidate_coords is None: + return None + + selected_coords = set(best_candidate_coords) + return [ + device + for device in devices + if device_mesh_coords(device) in selected_coords + ] + + +def _create_device_allocation_state( + devices: Sequence[Any] | None = None, + *, + log_summary: bool = True, +) -> DeviceAllocationState: + """Builds reusable allocator state for one or more mesh allocations. + + This is intentionally private because callers should not need to understand + the allocator internals to request one mesh. The public entry point is + `allocate_devices()`, which accepts either raw `devices` for one-shot use or + an existing `allocation_state` for incremental use. + """ + all_devices = tuple(jax.devices() if devices is None else devices) + if log_summary: + logging.info( + "Mesh allocator raw device sample: %s", + summarize_devices_for_debug_logging(all_devices), + ) + logging.info( + "Mesh allocator derived host groups: %s", + summarize_host_groups_for_logging(all_devices), + ) + remaining_host_groups = group_devices_by_host(all_devices) + full_devices_per_host = ( + len(remaining_host_groups[0]) if remaining_host_groups else 0 + ) + host_bound_shape = _infer_host_bound_shape(all_devices) + host_bound_device_count = _infer_host_bound_device_count( + host_bound_shape, + full_devices_per_host, + ) + if remaining_host_groups and ( + host_bound_shape is None or not host_bound_device_count + ): + raise ValueError( + "Host-group allocation requires an inferable host-bound shape and " + "device count." + ) + return DeviceAllocationState( + remaining_devices=all_devices, + remaining_host_groups=( + tuple(tuple(group) for group in remaining_host_groups) + if remaining_host_groups + else None + ), + full_devices_per_host=full_devices_per_host, + host_bound_shape=host_bound_shape, + host_bound_device_count=host_bound_device_count, + total_device_count=len(all_devices), + ) + + +def _allocate_devices_from_pool( + required_devices: int, + remaining_devices: list[Any], + remaining_host_groups: list[list[Any]] | None, + full_devices_per_host: int, + host_bound_shape: tuple[int, ...] | None, + host_bound_device_count: int | None, + mesh_name: str, +) -> tuple[list[Any], list[Any], list[list[Any]] | None]: + """Allocates one mesh from a concrete device pool without slice policy. + + This helper contains the pool-local allocation strategy used after any + slice-level decision has already been made. + """ + assigned_devices = allocate_devices_by_coords(remaining_devices, required_devices) + if assigned_devices is not None: + remaining_devices = _remove_devices_by_identity( + remaining_devices, + assigned_devices, + ) + remaining_host_groups = None + return assigned_devices, remaining_devices, remaining_host_groups + + if remaining_host_groups: + assigned_devices, remaining_host_groups = _allocate_from_host_groups( + remaining_host_groups, + required_devices, + full_devices_per_host, + host_bound_shape, + host_bound_device_count or 0, + mesh_name, + ) + remaining_devices = _remove_devices_by_identity( + remaining_devices, + assigned_devices, + ) + return assigned_devices, remaining_devices, remaining_host_groups + + if required_devices > len(remaining_devices): + raise ValueError( + f"Mesh allocation requires {required_devices} devices for {mesh_name}, " + f"but only {len(remaining_devices)} remain available." + ) + assigned_devices = remaining_devices[:required_devices] + remaining_devices = remaining_devices[required_devices:] + return assigned_devices, remaining_devices, remaining_host_groups + + +def allocate_devices( + required_devices: int, + devices: Sequence[Any] | None = None, + *, + mesh_name: str = "allocated_mesh", + allocation_state: DeviceAllocationState | None = None, + return_state: bool = False, +) -> list[Any] | tuple[list[Any], DeviceAllocationState]: + """Allocates devices for a single mesh request. + + This is the lowest-level public allocation API. It handles exactly one mesh + request and applies the allocator policy in priority order: + + 1. Prefer a contiguous coord-aligned box when device coords are available. + 2. Otherwise, use host-aware allocation without illegally breaking host + topology. + 3. Otherwise, fall back to a flat prefix of the remaining devices. + + There are two intended calling modes: + + 1. One-shot allocation: pass `devices=` and receive a single allocation. + 2. Incremental allocation: pass `allocation_state=` and, when + `return_state=True`, receive the updated remaining pool for the next call. + + `allocate_named_mesh_device_slices()` is implemented as a thin loop around + this function. + + Args: + required_devices: Number of devices to allocate for this mesh. + devices: Raw device pool for one-shot use. Mutually exclusive with + `allocation_state`. + mesh_name: Name used only for diagnostics and error messages. + allocation_state: Existing state for incremental allocation. + return_state: Whether to return the updated allocation state alongside the + assigned devices. + + Returns: + Either the assigned device list, or `(assigned_devices, next_state)` when + `return_state=True`. + + Raises: + ValueError: If both `devices` and `allocation_state` are provided, or if + the request cannot be satisfied from the remaining device pool. + """ + if devices is not None and allocation_state is not None: + raise ValueError( + "Pass either devices or allocation_state to allocate_devices, not both." + ) + + owns_state = allocation_state is None + state = allocation_state or _create_device_allocation_state(devices) + remaining_devices = list(state.remaining_devices) + remaining_host_groups = ( + [list(group) for group in state.remaining_host_groups] + if state.remaining_host_groups + else None + ) + assigned_devices = None + + slice_groups = group_devices_by_slice(remaining_devices) + if slice_groups and len(slice_groups) > 1: + # Prefer staying within one slice when a single slice can satisfy the whole + # request. This avoids accidental cross-slice meshes when slice metadata is + # available. + for slice_devices in slice_groups: + if len(slice_devices) < required_devices: + continue + slice_state = _create_device_allocation_state( + slice_devices, + log_summary=False, + ) + assigned_devices, _, _ = _allocate_devices_from_pool( + required_devices, + list(slice_state.remaining_devices), + ( + [list(group) for group in slice_state.remaining_host_groups] + if slice_state.remaining_host_groups + else None + ), + slice_state.full_devices_per_host, + slice_state.host_bound_shape, + slice_state.host_bound_device_count, + mesh_name, + ) + remaining_devices = _remove_devices_by_identity( + remaining_devices, + assigned_devices, + ) + remaining_host_groups = None + break + + # If no single slice is large enough, consume slices in order. This makes a + # cross-slice mesh grow by exhausting one slice before spilling into the + # next one. + if assigned_devices is None and len(remaining_devices) >= required_devices: + slice_order = [device_slice_id(group[0]) for group in slice_groups] + assigned_devices = [] + remaining_required = required_devices + for slice_id in slice_order: + if remaining_required == 0: + break + current_slice_devices = [ + device for device in remaining_devices if device_slice_id(device) == slice_id + ] + if not current_slice_devices: + continue + slice_request = min(remaining_required, len(current_slice_devices)) + slice_state = _create_device_allocation_state( + current_slice_devices, + log_summary=False, + ) + partial_devices, _, _ = _allocate_devices_from_pool( + slice_request, + list(slice_state.remaining_devices), + ( + [list(group) for group in slice_state.remaining_host_groups] + if slice_state.remaining_host_groups + else None + ), + slice_state.full_devices_per_host, + slice_state.host_bound_shape, + slice_state.host_bound_device_count, + mesh_name, + ) + assigned_devices.extend(partial_devices) + remaining_devices = _remove_devices_by_identity( + remaining_devices, + partial_devices, + ) + remaining_required -= len(partial_devices) + remaining_host_groups = None + + if assigned_devices is None: + assigned_devices, remaining_devices, remaining_host_groups = _allocate_devices_from_pool( + required_devices, + remaining_devices, + remaining_host_groups, + state.full_devices_per_host, + state.host_bound_shape, + state.host_bound_device_count, + mesh_name, + ) + + next_state = dataclasses.replace( + state, + remaining_devices=tuple(remaining_devices), + remaining_host_groups=( + tuple(tuple(group) for group in remaining_host_groups) + if remaining_host_groups + else None + ), + used_device_count=state.used_device_count + len(assigned_devices), + ) + logging.info( + "Allocated devices for %s: %s", + mesh_name, + summarize_devices_for_logging(assigned_devices), + ) + + if owns_state and not return_state: + unused_device_count = next_state.total_device_count - next_state.used_device_count + if unused_device_count > 0: + logging.warning( + "Mesh allocation used %d of %d devices; %d devices remain unused.", + next_state.used_device_count, + next_state.total_device_count, + unused_device_count, + ) + + if return_state: + return assigned_devices, next_state + return assigned_devices + + +def _remove_devices_by_identity( + devices: Sequence[Any], + assigned_devices: Sequence[Any], +) -> list[Any]: + assigned_device_ids = {id(device) for device in assigned_devices} + return [device for device in devices if id(device) not in assigned_device_ids] + + +def _infer_host_bound_device_count( + host_bound_shape: tuple[int, ...] | None, + full_devices_per_host: int, +) -> int | None: + """Infers the smallest host-aligned device-count unit when possible.""" + if host_bound_shape is None or full_devices_per_host <= 0: + return None + + host_bound_device_count = int(np.prod(host_bound_shape)) + if host_bound_device_count <= 0: + return None + if host_bound_device_count > full_devices_per_host: + return None + return host_bound_device_count + + +def _infer_host_bound_shape(devices: Sequence[Any]) -> tuple[int, ...] | None: + """Infers the per-host bound shape when topology metadata is available.""" + return resolve_per_host_mesh_shape(devices) + + +def _allocate_devices_within_host_group( + host_devices: Sequence[Any], + required_devices: int, + host_bound_shape: tuple[int, ...], +) -> list[Any] | None: + """Allocates devices from one host bucket using coord-aware selection.""" + coord_topology = get_coord_topology(host_devices) + if coord_topology is None: + return None + + candidate_boxes = find_host_aligned_candidate_coord_boxes( + coord_topology, + required_devices, + host_bound_shape, + ) + if not candidate_boxes: + candidate_boxes = find_candidate_coord_boxes(coord_topology, required_devices) + + best_candidate_coords = select_best_candidate_coords( + candidate_boxes, + host_bound_shape, + ) + if best_candidate_coords is None: + return None + + selected_coords = set(best_candidate_coords) + return [ + device + for device in host_devices + if device_mesh_coords(device) in selected_coords + ] + + +def _satisfies_host_bound_shape( + host_devices: Sequence[Any], + host_bound_shape: tuple[int, ...] | None, + host_bound_device_count: int, +) -> bool: + if host_bound_shape is None or host_bound_device_count <= 0: + raise ValueError( + "host_bound_shape and host_bound_device_count must be set for " + "host-group allocation." + ) + return ( + _allocate_devices_within_host_group( + host_devices, + len(host_devices), + host_bound_shape, + ) + is not None + ) + + +def _allocate_partial_host_group( + host_groups: Sequence[Sequence[Any]], + required_devices: int, + host_bound_shape: tuple[int, ...], + host_bound_device_count: int, + mesh_name: str, +) -> tuple[list[Any], list[list[Any]] | None] | None: + """Allocates a request from one host bucket if a compatible bucket exists. + + This helper deliberately does not merge fragments from different hosts. The + policy is to satisfy a partial request from exactly one remaining host bucket + and to keep the leftover from that same bucket only if the leftover still + forms a host-valid shape. If taking the prefix would leave an invalid host + fragment behind, this host bucket is skipped and the allocator tries the next + one. + """ + for host_index, host_devices in enumerate(host_groups): + if len(host_devices) < required_devices: + continue + if not _satisfies_host_bound_shape( + host_devices, + host_bound_shape, + host_bound_device_count, + ): + logging.info( + "Skipping remaining host group for %s because %d devices do not " + "satisfy inferred host-bound shape %s.", + mesh_name, + len(host_devices), + host_bound_shape, + ) + continue + assigned_devices = list(host_devices[:required_devices]) + remaining_devices_for_host = list(host_devices[required_devices:]) + if remaining_devices_for_host and not _satisfies_host_bound_shape( + remaining_devices_for_host, + host_bound_shape, + host_bound_device_count, + ): + logging.info( + "Skipping remaining host group for %s because taking %d devices would " + "leave %d devices that do not satisfy host-bound shape %s.", + mesh_name, + required_devices, + len(remaining_devices_for_host), + host_bound_shape, + ) + continue + remaining_host_groups = [list(group) for group in host_groups] + if remaining_devices_for_host: + remaining_host_groups[host_index] = remaining_devices_for_host + else: + del remaining_host_groups[host_index] + return assigned_devices, remaining_host_groups or None + return None + + +def _allocate_from_host_groups( + host_groups: Sequence[Sequence[Any]], + required_devices: int, + full_devices_per_host: int, + host_bound_shape: tuple[int, ...] | None, + host_bound_device_count: int, + mesh_name: str, +) -> tuple[list[Any], list[list[Any]] | None]: + """Allocates from remaining per-host buckets while preserving leftovers. + + This path is used only after coord allocation fails. + + Why this exists: + + 1. Some environments expose enough host metadata to preserve host boundaries + even when a full coord topology is unavailable or unsuitable. + 2. We want to allow partial-host reuse across meshes, but only when the + leftover fragment still has a valid host-bounded shape. + 3. We do not want to silently assemble one logical "host" out of unrelated + fragments taken from multiple different hosts. + + Policy: + + 1. Allocate the whole-host portion first when `required_devices` spans one or + more full hosts. + 2. Allocate any remainder from exactly one remaining host bucket. + 3. Reject the request if that remainder cannot be taken without leaving an + invalid host fragment behind. + """ + if host_bound_shape is None or host_bound_device_count <= 0: + raise ValueError( + "Host-group allocation requires an inferable host-bound shape and " + "device count." + ) + if full_devices_per_host <= 0: + raise ValueError( + "Host-group allocation requires a positive full host device count." + ) + + remaining_host_groups = [list(group) for group in host_groups] + assigned_devices = [] + + required_full_hosts = required_devices // full_devices_per_host + remainder_devices = required_devices % full_devices_per_host + + if required_full_hosts: + full_host_indices = [ + index + for index, host_devices in enumerate(remaining_host_groups) + if len(host_devices) == full_devices_per_host + ] + if required_full_hosts > len(full_host_indices): + raise ValueError( + f"Mesh allocation requires {required_full_hosts} hosts for {mesh_name}, " + f"but only {len(full_host_indices)} are available." + ) + + selected_host_indices = set(full_host_indices[:required_full_hosts]) + assigned_devices.extend([ + device + for index, host_devices in enumerate(remaining_host_groups) + if index in selected_host_indices + for device in host_devices + ]) + remaining_host_groups = [ + list(host_devices) + for index, host_devices in enumerate(remaining_host_groups) + if index not in selected_host_indices + ] + + if remainder_devices: + partial_allocation = _allocate_partial_host_group( + remaining_host_groups, + remainder_devices, + host_bound_shape, + host_bound_device_count, + mesh_name, + ) + if partial_allocation is None: + raise ValueError( + f"Mesh allocation for {mesh_name} requires {required_devices} devices, " + f"but no remaining host group can satisfy the remaining {remainder_devices} devices." + ) + partial_devices, remaining_host_groups = partial_allocation + assigned_devices.extend(partial_devices) + + if not assigned_devices: + partial_allocation = _allocate_partial_host_group( + remaining_host_groups, + required_devices, + host_bound_shape, + host_bound_device_count, + mesh_name, + ) + if partial_allocation is None: + raise ValueError( + f"Mesh allocation for {mesh_name} requires {required_devices} devices, " + "but no remaining host has enough capacity to satisfy that request." + ) + assigned_devices, remaining_host_groups = partial_allocation + + return assigned_devices, remaining_host_groups or None + + +def allocate_named_mesh_device_slices( + mesh_requirements: Sequence[MeshRequirement], + devices: Sequence[Any] | None = None, +) -> dict[str, list[Any]]: + """Allocates device subsets for named meshes. + + This is a convenience wrapper over `allocate_devices()` for callers that want + several named allocations from one shared device pool. + + The function builds one `DeviceAllocationState`, then calls + `allocate_devices()` once per `(mesh_name, required_devices)` pair. That + keeps the single-mesh allocation policy centralized in one public API instead + of duplicating decision logic here. + + Args: + mesh_requirements: Sequence of (mesh_name, required_devices) pairs. + Example: [("actor", 8), ("rollout", 4)]. The mesh_name is only used for + logging and as the key in the returned dictionary. + devices: Optional explicit device list. When omitted, this uses + jax.devices(). + + Returns: + A dictionary mapping each mesh name to the list of devices assigned to it. + + Raises: + ValueError: If a requested mesh cannot be assigned enough devices or if a + host-based allocation would split hosts illegally. + """ + state = _create_device_allocation_state(devices) + allocations = {} + + for mesh_name, required_devices in mesh_requirements: + assigned_devices, state = allocate_devices( + required_devices, + mesh_name=mesh_name, + allocation_state=state, + return_state=True, + ) + allocations[mesh_name] = assigned_devices + + unused_device_count = state.total_device_count - state.used_device_count + if unused_device_count > 0: + logging.warning( + "Mesh allocation used %d of %d devices; %d devices remain unused.", + state.used_device_count, + state.total_device_count, + unused_device_count, + ) + logging.info( + "Mesh device allocation: %s", + {mesh_name: len(assigned_devices) for mesh_name, assigned_devices in allocations.items()}, + ) + return allocations diff --git a/tunix/utils/topology.py b/tunix/utils/topology.py new file mode 100644 index 000000000..3820e0416 --- /dev/null +++ b/tunix/utils/topology.py @@ -0,0 +1,63 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal accelerator topology helpers used by Tunix mesh allocation.""" + +from typing import Any, Sequence + +_SINGLE_HOST_BOUNDS = (1, 1, 1) +_MULTI_HOST_BOUNDS = (2, 2, 1) + + +def _device_attr(device: Any, attr_name: str) -> Any: + """Returns a raw device attribute, calling it first when exposed lazily.""" + value = getattr(device, attr_name, None) + return value() if callable(value) else value + + +def _normalize_device_kind(device_kind: str) -> str | None: + device_kind = device_kind.lower() + if "v7" in device_kind: + return "tpu7x" + if "v6e" in device_kind or "v6" in device_kind: + return "v6e" + if "v5e" in device_kind: + return "v5e" + if "v5" in device_kind: + return "v5p" + if "v4" in device_kind: + return "v4" + return None + + +def infer_chips_per_host_bounds( + devices: Sequence[Any], +) -> tuple[int, ...] | None: + if not devices: + return None + + device_kind = _device_attr(devices[0], "device_kind") + if not isinstance(device_kind, str): + return None + + family = _normalize_device_kind(device_kind) + if family is None: + return None + + device_count = len(devices) + if family in {"v5e", "v6e"} and device_count == 1: + return _SINGLE_HOST_BOUNDS + if family == "tpu7x" and device_count == 2: + return _SINGLE_HOST_BOUNDS + return _MULTI_HOST_BOUNDS