Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/cpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ on:

permissions:
contents: read
jobs:
jobs:
run:
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -63,6 +63,10 @@ jobs:
run: |
python -m pytest tests/cli/utils/ -v --tb=short
- name: Run shared mesh and topology tests
run: |
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short
- name: Run perf tests
run: |
python -m pytest tests/perf/ -v --tb=short
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ jobs:
- name: Run tunix tests not covered by the above categories
run: |
# This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
if [ "${code:-0}" = "5" ]; then
echo "No tests collected (expected)."
exit 0
Expand Down
80 changes: 46 additions & 34 deletions tests/cli/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from tunix.sft import peft_trainer
from tunix.tests import test_common as tc
from tunix.utils import env_utils
from tunix.utils import mesh as mesh_lib

os.environ.setdefault("HF_TOKEN", "TestToken")


class ConfigTest(parameterized.TestCase):
Expand Down Expand Up @@ -262,7 +265,7 @@ def test_learning_rate_schedule_valid(self, overrides):
self.assertIsNotNone(lr_schedule)
self.assertTrue(callable(lr_schedule), "lr_schedule should be callable")

# --- Tests for create_mesh ---
# --- Tests for mesh config parsing and mesh creation ---
@parameterized.named_parameters(
dict(
testcase_name="valid_1d",
Expand Down Expand Up @@ -311,40 +314,48 @@ def test_create_mesh_valid(
):
mock_device_count_fn.return_value = mock_num_devices
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
mesh = hp.create_mesh("model_config")
self.assertEqual(
mesh,
jax.make_mesh(
expected[0],
expected[1],
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
),
)
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
expected_mesh = object()

def test_create_mesh_with_assigned_devices(self):
raw_keys = {
"model_config": {
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
}
}
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
assigned_devices = ["d0", "d1", "d2", "d3"]
with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock:
mesh = mesh_lib.create_mesh(axis_shapes, axis_names)

class FakeMesh:
make_mesh_mock.assert_called_once_with(
expected[0],
expected[1],
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
)
self.assertIs(mesh, expected_mesh)

def test_create_mesh_with_assigned_devices(self):
raw_keys = {
"model_config": {
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
}
}
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
assigned_devices = ["d0", "d1", "d2", "d3"]

def __init__(self, devices, axis_names, axis_types=None):
self.devices = devices
self.axis_names = axis_names
self.axis_types = axis_types
class FakeMesh:

with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
mesh = hp.create_mesh("model_config", devices=assigned_devices)
def __init__(self, devices, axis_names, axis_types=None):
self.devices = devices
self.axis_names = axis_names
self.axis_types = axis_types

self.assertEqual(mesh.devices.shape, (2, 2))
self.assertSequenceEqual(
mesh.devices.flatten().tolist(), assigned_devices
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
mesh = mesh_lib.create_mesh(
axis_shapes,
axis_names,
devices=assigned_devices,
)
self.assertEqual(mesh.axis_names, ("x", "y"))

self.assertEqual(mesh.devices.shape, (2, 2))
self.assertSequenceEqual(
mesh.devices.flatten().tolist(), assigned_devices
)
self.assertEqual(mesh.axis_names, ("x", "y"))

@parameterized.named_parameters(
dict(
Expand Down Expand Up @@ -424,11 +435,12 @@ def test_create_mesh_invalid(
mock_num_devices,
error_regex,
):
mock_device_count_fn.return_value = mock_num_devices
with self.assertRaisesRegex(ValueError, error_regex):
nested_dict = self.convert_nested_dict_to_list(raw_keys)
hp = self.initialize_config(nested_dict)
hp.create_mesh("model_config")
mock_device_count_fn.return_value = mock_num_devices
with self.assertRaisesRegex(ValueError, error_regex):
nested_dict = self.convert_nested_dict_to_list(raw_keys)
hp = self.initialize_config(nested_dict)
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
mesh_lib.create_mesh(axis_shapes, axis_names)

@parameterized.named_parameters(
dict(
Expand Down
101 changes: 101 additions & 0 deletions tests/cli/grpo_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,107 @@ def __init__(self, devices, axis_names, axis_types=None):
role_to_mesh[rl_cluster_lib.Role.ACTOR],
)

def test_split_mesh_delegates_device_allocation_to_mesh_utils(self):
extra = """
training_mode: "agentic_grpo"
data_module: "tunix.cli.recipes.deepscaler_data"
apply_chat_template_to_dataset: false
data_config:
train_data_path: "gs://fake/train.json"
eval_data_path: "gs://fake/eval.parquet"
prompt_key: "prompts"
reward_functions: []
verl_compatible: false
chat_parser_config:
type: "default"
agent_class_path: null
agent_kwargs: {}
env_class_path: null
env_kwargs: {}
kubernetes_config: null
agentic_grpo_config:
num_generations: 2
num_iterations: 1
beta: 0.0
epsilon: 0.2
epsilon_high: 0.28
system_prompt: ""
max_concurrency: 1
off_policy_steps: 0
max_turns: 1
context_ratio: 1
sglang_jax_config:
mem_fraction_static: 0.8
vllm_config:
hbm_utilization: 0.4
"""
pipeline = _make_pipeline(extra)
actor_model_config = pipeline.config["actor_model_config"]
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
actor_model_config["mesh"] = {
"shape": "(1,2)",
"axis_names": "('fsdp','tp')",
}
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
rollout_model_config = pipeline.config["rollout_model_config"]
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
rollout_model_config["mesh"] = {
"shape": "(1,2)",
"axis_names": "('fsdp','tp')",
}

fake_devices = ["a0", "a1", "r0", "r1"]
allocated_devices = {
"actor_model_config": ["a0", "a1"],
"rollout_model_config": ["r0", "r1"],
}
created_mesh_devices = {}

def fake_create_mesh(axis_shapes, axis_names, devices=None):
model_key = (
"actor_model_config"
if axis_shapes == (1, 2) and axis_names == ("fsdp", "tp")
and "actor_model_config" not in created_mesh_devices
else "rollout_model_config"
)
created_mesh_devices[model_key] = list(devices)
return (model_key, tuple(devices))

with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
with mock.patch.object(
grpo_main.mesh_lib,
"allocate_named_mesh_device_slices",
return_value=allocated_devices,
) as allocate_mock:
with mock.patch.object(
grpo_main.mesh_lib,
"create_mesh",
side_effect=fake_create_mesh,
):
role_to_mesh = pipeline.create_role_to_mesh()

allocate_mock.assert_called_once_with(
[
("actor_model_config", 2),
("rollout_model_config", 2),
],
devices=fake_devices,
)
self.assertEqual(created_mesh_devices["actor_model_config"], ["a0", "a1"])
self.assertEqual(created_mesh_devices["rollout_model_config"], ["r0", "r1"])
self.assertEqual(
role_to_mesh[rl_cluster_lib.Role.ACTOR],
("actor_model_config", ("a0", "a1")),
)
self.assertIs(
role_to_mesh[rl_cluster_lib.Role.REFERENCE],
role_to_mesh[rl_cluster_lib.Role.ACTOR],
)
self.assertEqual(
role_to_mesh[rl_cluster_lib.Role.ROLLOUT],
("rollout_model_config", ("r0", "r1")),
)


if __name__ == "__main__":
absltest.main()
Loading
Loading