Skip to content

Add communication-free Muon for FlexShard#3502

Draft
weifengpy wants to merge 13 commits into
gh/weifengpy/30/basefrom
gh/weifengpy/30/head
Draft

Add communication-free Muon for FlexShard#3502
weifengpy wants to merge 13 commits into
gh/weifengpy/30/basefrom
gh/weifengpy/30/head

Conversation

@weifengpy

@weifengpy weifengpy commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

[ghstack-poisoned]
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 3, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: fe8ef01
Pull-Request: #3502
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 3, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: 17456cf
Pull-Request: #3502
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: 651b804
Pull-Request: #3502
@weifengpy weifengpy marked this pull request as draft June 4, 2026 03:37
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: 2cbbe74
Pull-Request: #3502
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: e70eda4
Pull-Request: #3502
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: 7274abe
Pull-Request: #3502
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 4, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: 151c99c
Pull-Request: #3502
weifengpy added 4 commits June 8, 2026 13:38
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 9, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: 5a171d5
Pull-Request: #3502
[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Jun 10, 2026
Place each Muon-eligible 2D matrix whole on one rank via the Owned placement so Newton-Schulz runs locally on the owner after the backward reduce-to-owner -- no collective in optimizer.step(), and bit-exact with single-device Muon. Layers are balanced across ranks with greedy LPT; embeddings, LM head, and final norm stay Shard(0) + AdamW. Owned now composes with reshard_after_forward=True (broadcast ops tagged for activation-checkpoint recompute).

Adds example/muon.py (comm_free_muon_buckets, build_muon_param_groups, build_comm_free_muon_optimizers, CombinedOptimizer) and example/owned.py helpers (make_owned_placement_fn, assign_layer_owners_lpt).

Also adds GroupedMuon: batched Newton-Schulz over the leading dim of stacked weight matrices (>=3D, e.g. MoE grouped experts). build_comm_free_muon_optimizers routes 2D params to torch.optim.Muon and >=3D stacks to GroupedMuon; GroupedMuon matches running torch.optim.Muon on each 2D sub-matrix.

Tests: python -m pytest -q torchtitan/experiments/flex_shard/tests/test_flex_shard_muon.py

ghstack-source-id: 702ec18
Pull-Request: #3502
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant