Thread custom process groups through MoE grad finalization#19

Open: yashaswikarnati wants to merge 4 commits into main from ykarnati/moe-custom-process-groups

@yashaswikarnati
Owner

Summary

  • use explicit MoE tensor-parallel groups for TE checkpointing and shared-expert TP collectives
  • allow router expert-bias finalization to reduce on an explicit TPxDPxCP process group
  • require explicit pg_collection.tp_dp_cp when expert-bias updates run with custom process groups
  • add focused coverage for group forwarding and early validation
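The early-validation rule in the summary can be illustrated with a minimal sketch. This is not the PR's code: `PgCollectionSketch`, `resolve_tp_dp_cp_group`, and `default_group_fn` are hypothetical names introduced for illustration, and the sketch assumes that a missing collection means "fall back to the globally registered group" while a supplied collection must carry an explicit `tp_dp_cp` group.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class PgCollectionSketch:
    """Hypothetical stand-in for a process-group collection (illustration only)."""

    tp: Optional[Any] = None
    tp_dp_cp: Optional[Any] = None


def resolve_tp_dp_cp_group(
    pg_collection: Optional[PgCollectionSketch],
    default_group_fn: Callable[[], Any],
) -> Any:
    """Pick the group used to reduce router expert-bias statistics.

    Assumed behavior for this sketch:
    - no collection passed: use the globally registered group;
    - collection passed: it must provide an explicit tp_dp_cp group,
      otherwise fail before any collective is issued.
    """
    if pg_collection is None:
        return default_group_fn()
    group = getattr(pg_collection, "tp_dp_cp", None)
    if group is None:
        raise ValueError(
            "pg_collection.tp_dp_cp is required when expert-bias updates "
            "run with custom process groups"
        )
    return group
```

Failing at resolution time, rather than inside the collective, keeps the error on every rank and avoids a hang caused by ranks reducing over different groups.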

Testing

  • cog batch job 11732315: python -m pytest tests/unit_tests/transformer/moe/test_routers.py::test_get_updated_expert_bias_uses_explicit_group tests/unit_tests/distributed/test_finalize_model_grads.py::test_update_router_expert_bias_uses_explicit_group tests/unit_tests/distributed/test_finalize_model_grads.py::test_finalize_model_grads_uses_pg_collection_tp_dp_cp tests/unit_tests/distributed/test_finalize_model_grads.py::test_finalize_model_grads_requires_tp_dp_cp_for_explicit_groups -q
  • pre-push hooks: black, pylint, isort

# TODO(Hepteract): delete the usage of the global parallel_state.
group=parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True),
)
tp_dp_cp_group = parallel_state.get_tensor_and_data_parallel_group(
Owner Author

Can we remove the TODO on line 1183?

from tests.unit_tests.test_utilities import Utils


_MISSING = object()
Owner Author

I don't like fake tests built on monkeypatches. CI actually runs the tests on 8 GPUs. Is there a less verbose way to do actual testing?

Owner Author

Addressed in 53f2fa5. Removed the monkeypatch-based finalize tests and replaced them with real distributed coverage: the tests initialize model parallel, build actual ProcessGroupCollection values, run finalize_model_grads, and verify router expert-bias update / early validation behavior. Verified on an 8-GPU cog job: 11732530.

os.environ.pop('NVTE_FLASH_ATTN', None)
os.environ.pop('NVTE_UNFUSED_ATTN', None)
Utils.destroy_model_parallel()
Utils.initialize_model_parallel(tensor_model_parallel_size=min(2, Utils.world_size))
Owner Author

here

yashaswikarnati force-pushed the ykarnati/moe-custom-process-groups branch 3 times, most recently from 7b4bb6b to 95aeafe (May 13, 2026 15:54)