Add communication-free Muon for FlexShard#3502
Draft
weifengpy wants to merge 13 commits into
Draft
Conversation
This was referenced Jun 3, 2026
weifengpy
added a commit
that referenced
this pull request
Jun 3, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: fe8ef01 Pull-Request: #3502
weifengpy
added a commit
that referenced
this pull request
Jun 3, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: 17456cf Pull-Request: #3502
weifengpy
added a commit
that referenced
this pull request
Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: 651b804 Pull-Request: #3502
weifengpy
added a commit
that referenced
this pull request
Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: 2cbbe74 Pull-Request: #3502
weifengpy
added a commit
that referenced
this pull request
Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: e70eda4 Pull-Request: #3502
weifengpy
added a commit
that referenced
this pull request
Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: 7274abe Pull-Request: #3502
weifengpy
added a commit
that referenced
this pull request
Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: 151c99c Pull-Request: #3502
weifengpy
added a commit
that referenced
this pull request
Jun 9, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: 5a171d5 Pull-Request: #3502
weifengpy
added a commit
that referenced
this pull request
Jun 10, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute). Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt). Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix. Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py ghstack-source-id: 702ec18 Pull-Request: #3502
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stack from ghstack (oldest at bottom):
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).
Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).
Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.
Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py