Skip to content

Add deterministic topk for MoE routing#3600

Open
sanketpurandare wants to merge 1 commit into
mainfrom
sanketpurandare/stack/20
Open

Add deterministic topk for MoE routing#3600
sanketpurandare wants to merge 1 commit into
mainfrom
sanketpurandare/stack/20

Conversation

@sanketpurandare

@sanketpurandare sanketpurandare commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

Route MoE expert selection through a TorchTitan custom op that enables PyTorch's deterministic topk implementation locally and restores the caller's deterministic-algorithm state afterward. This gives activation-checkpoint recompute the same stable top-k tie-breaking behavior without saving raw aten.topk outputs in selective activation checkpointing.

The wrapper follows the deterministic_scatter_add pattern, includes fake tensor and autograd registrations, and relies on the PyTorch deterministic topk implementation when available.

Test Plan:

  • python -m py_compile torchtitan/ops/topk.py torchtitan/models/common/moe.py torchtitan/distributed/activation_checkpoint.py tests/unit_tests/test_deterministic_ops.py

  • pytest tests/unit_tests/test_deterministic_ops.py -q

  • pytest tests/unit_tests/test_activation_checkpoint.py -q

  • pytest tests/unit_tests/test_compile_moe.py -q

  • pre-commit run --files torchtitan/ops/topk.py torchtitan/models/common/moe.py torchtitan/distributed/activation_checkpoint.py tests/unit_tests/test_deterministic_ops.py

  • pre-commit run --all-files

Route MoE expert selection through a TorchTitan custom op that enables PyTorch's deterministic topk implementation locally and restores the caller's deterministic-algorithm state afterward. This gives activation-checkpoint recompute the same stable top-k tie-breaking behavior without saving raw aten.topk outputs in selective activation checkpointing.

The wrapper follows the deterministic_scatter_add pattern, includes fake tensor and autograd registrations, and relies on the PyTorch deterministic topk implementation when available.

Test Plan:

- python -m py_compile torchtitan/ops/topk.py torchtitan/models/common/moe.py torchtitan/distributed/activation_checkpoint.py tests/unit_tests/test_deterministic_ops.py

- pytest tests/unit_tests/test_deterministic_ops.py -q

- pytest tests/unit_tests/test_activation_checkpoint.py -q

- pytest tests/unit_tests/test_compile_moe.py -q

- pre-commit run --files torchtitan/ops/topk.py torchtitan/models/common/moe.py torchtitan/distributed/activation_checkpoint.py tests/unit_tests/test_deterministic_ops.py

- pre-commit run --all-files

stack-info: PR: #3600, branch: sanketpurandare/stack/20
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/20 branch from c1033f9 to bc36ecf Compare June 10, 2026 01:42
torch.ops.aten.linear.default,
# topk can be non-deterministic; save to keep MoE expert assignments
# stable between forward and recompute.
torch.ops.aten.topk.default,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious -- what makes "always saving topk" in SAC policy bad?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not bad but unnecessary.

Comment thread torchtitan/ops/topk.py
) -> tuple[torch.Tensor, torch.Tensor]:
prev = torch.are_deterministic_algorithms_enabled()
prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
torch.use_deterministic_algorithms(True, warn_only=False)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@songhappy does it break your use case?

@sanketpurandare sanketpurandare added the ciflow/h100.8 Trigger H100.8 CI label Jun 10, 2026
@sanketpurandare sanketpurandare marked this pull request as ready for review June 10, 2026 15:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/h100.8 Trigger H100.8 CI ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants