[alignment-megatron] PR7A: SGLang Ulysses CP attention and RoPE #27
Draft
maocheng23 wants to merge 3 commits into
Add a clean Megatron backend that calls SGLang-compatible math under a flag:
- sglang.py: SGLangLinear, SGLangRMSNorm, SGLangFlashAttention and related modules
- matmul_tp_inv.py: TP-invariant matmul dispatch for Megatron layers
- transformer_config.py: use_sglang config flag
- arguments.py: --use-sglang CLI arg
- layers.py: conditional SGLang backend selection in TP layers
- gpt_layer_specs.py: SGLang-compatible layer spec builder
- test_sglang_extension.py: import, config, and default-path-unchanged tests
Default training path remains unchanged when use_sglang is off.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Match SGLang's TP reduction order and full-vocab logprob contract:
- mappings.py: tree_all_reduce_sum for deterministic TP reduction
- layers.py: conditional tree allreduce in RowParallelLinear
- gpt_model.py: full-vocab logprob gather/truncate/log-softmax
- transformer_config.py: true_on_policy_logits config
- test_tree_all_reduce.py: TP tree-allreduce tests
- test_true_on_policy_logits.py: full-vocab gather/truncate tests
Default NCCL allreduce path unchanged when flags are off.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
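Why a fixed reduction order matters can be shown without any GPUs. The sketch below is not the tree_all_reduce_sum from mappings.py, whose implementation is not shown here; tree_reduce_sum is a hypothetical single-process stand-in that sums per-rank partial results pairwise in a fixed binary-tree order. Floating-point addition is not associative, so a left-to-right sum and a tree sum of the same partials can differ, and two systems only agree bit-for-bit if they reduce in the same order.

```python
from typing import List, Sequence


def tree_reduce_sum(partials: Sequence[float]) -> float:
    """Sum per-rank partials in a fixed pairwise (binary-tree) order.

    Hypothetical illustration only; a real TP all-reduce would exchange
    tensors between ranks instead of walking a local list.
    """
    if not partials:
        raise ValueError("need at least one partial result")
    level: List[float] = list(partials)
    while len(level) > 1:
        # Combine neighbours (0,1), (2,3), ... so the order never varies.
        nxt = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            nxt.append(level[-1])  # an odd element carries over unchanged
        level = nxt
    return level[0]


# Same four partials, two different association orders.
partials = [1e16, 1.0, -1e16, 1.0]
sequential = ((partials[0] + partials[1]) + partials[2]) + partials[3]
tree = tree_reduce_sum(partials)
orders_disagree = sequential != tree
```

With these inputs the two orders disagree, while repeated calls to tree_reduce_sum on the same partials always return the same value.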
Implements PR 7A, stacked on PR 7. Adds the Megatron-side half of
Ulysses (cp_comm_type=a2a) context-parallel support for the SGLang
backend, matching the Miles-side Ulysses CP contract (miles PR 11A).
Non-Ulysses paths are unchanged: when cp_comm_type is not "a2a" (or
cp_size <= 1), SGLangCoreAttention falls back to DotProductAttention
and RoPE continues to use the existing cp_size/cp_rank from the CP
group.
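The fallback rule above can be sketched as a small predicate plus a dispatcher. This is a minimal illustration, not the PR's code: ToyConfig, is_ulysses_cp, and pick_core_attention are hypothetical names, the dispatcher returns a class name rather than building a module, and the handling of a list-valued cp_comm_type (Ulysses if any entry is "a2a") is an assumption.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class ToyConfig:
    """Hypothetical stand-in holding only the two fields the rule reads."""

    context_parallel_size: int = 1
    cp_comm_type: Optional[Union[str, List[str]]] = None


def is_ulysses_cp(config: ToyConfig) -> bool:
    """True only when CP is active and the comm type selects a2a (Ulysses)."""
    if config.context_parallel_size <= 1:
        return False
    comm = config.cp_comm_type
    if comm is None:
        return False
    if isinstance(comm, str):
        return comm == "a2a"
    # Assumption: a list counts as Ulysses if any entry is "a2a".
    return "a2a" in comm


def pick_core_attention(config: ToyConfig) -> str:
    """Return the name of the attention implementation the rule selects."""
    if is_ulysses_cp(config):
        return "SGLangFlashAttention"
    return "DotProductAttention"


ulysses = pick_core_attention(ToyConfig(context_parallel_size=2, cp_comm_type="a2a"))
ring = pick_core_attention(ToyConfig(context_parallel_size=2, cp_comm_type="p2p"))
```

Keeping the predicate separate from the dispatcher lets the same check drive both the attention choice and the RoPE flag.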
Changes:
- megatron/core/extensions/sglang.py
  - SGLangFlashAttention: FA3 varlen attention that slices heads
    across the CP group on input and all-gathers heads on output
    (Ulysses a2a over the head axis). Handles both packed (thd) and
    unpacked inputs, BF16 cast, GQA repeat, and FA3 kwarg
    compatibility via inspect.signature.
  - SGLangCoreAttention: thin dispatcher that picks
    SGLangFlashAttention when config.context_parallel_size > 1 and
    cp_comm_type == "a2a", otherwise DotProductAttention. Preserves
    the current_max_attn_logits plumbing expected by Megatron.
  - SGLangSpecProvider.core_attention(): now returns
    SGLangCoreAttention.
- megatron/core/models/common/embeddings/rope_utils.py
  - _apply_rotary_pos_emb_thd and apply_rotary_pos_emb take a new
    ulysses_cp flag. When set, RoPE uses cp_size=1/cp_rank=0 so the
    freqs tensor is applied to the full unsplit sequence (Ulysses
    keeps the full sequence on each rank; sharding happens inside
    attention only).
- megatron/core/transformer/attention.py
  - _is_ulysses_cp helper that reads cp_comm_type from config
    (handles both str and list[str] forms) and threads ulysses_cp
    into the RoPE call sites for both Q and K.
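The head-sharding idea behind SGLangFlashAttention can be checked on one process. The sketch below is an illustration under stated assumptions, not the PR's implementation: it uses plain softmax attention in pure Python instead of FA3, simulates the CP ranks with a loop, replaces the all-gather with list concatenation, and leaves out masking, packing, the BF16 cast, and GQA. All function names are hypothetical. Heads never interact inside attention, so giving each rank a slice of heads over the full sequence and concatenating the results along the head axis reproduces the unsharded output.

```python
import math
from typing import List

Head = List[List[float]]  # one head, indexed as [seq][head_dim]


def attend_one_head(q: Head, k: Head, v: Head) -> Head:
    """Plain softmax attention for a single head (no masking)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out: Head = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append(
            [sum(w * vj[d] for w, vj in zip(weights, v)) / total for d in range(len(v[0]))]
        )
    return out


def attend(q: List[Head], k: List[Head], v: List[Head]) -> List[Head]:
    """Attention over a list of heads; each head is computed independently."""
    return [attend_one_head(qh, kh, vh) for qh, kh, vh in zip(q, k, v)]


def ulysses_attend(q: List[Head], k: List[Head], v: List[Head], cp_size: int) -> List[Head]:
    """Simulate every CP rank in a loop: slice heads, attend, gather heads."""
    num_heads = len(q)
    if num_heads % cp_size != 0:
        raise ValueError("num_heads must be divisible by cp_size")
    per_rank = num_heads // cp_size
    gathered: List[Head] = []
    for rank in range(cp_size):
        lo, hi = rank * per_rank, (rank + 1) * per_rank
        # Each simulated rank sees the full sequence but only its own heads.
        gathered.extend(attend(q[lo:hi], k[lo:hi], v[lo:hi]))
    return gathered  # concatenation stands in for the head all-gather


def toy_tensor(heads: int, seq: int, dim: int, phase: float) -> List[Head]:
    """Deterministic test data, indexed as [head][seq][dim]."""
    return [
        [[math.sin(phase + 7 * h + 3 * s + d) for d in range(dim)] for s in range(seq)]
        for h in range(heads)
    ]


q, k, v = toy_tensor(4, 5, 2, 0.1), toy_tensor(4, 5, 2, 0.2), toy_tensor(4, 5, 2, 0.3)
reference = attend(q, k, v)
sharded = ulysses_attend(q, k, v, cp_size=2)
```

Here sharded equals reference exactly, because every head goes through the same arithmetic in both paths.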
Tests:
- tests/unit_tests/extension/test_sglang_ulysses_cp.py
  - _is_ulysses_cp detects a2a, including the list form, and rejects
    cp_size==1 or non-a2a cp_comm_type.
  - apply_rotary_pos_emb with ulysses_cp=True produces the same
    output as an unsplit (cp_size=1) RoPE pass.
  - SGLangCoreAttention dispatches to SGLangFlashAttention under
    Ulysses and falls back to DotProductAttention otherwise.
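The RoPE equivalence that the second test checks comes down to which sequence positions a rank applies frequencies to. The sketch below is a hypothetical model, not the code in rope_utils.py: rope_positions is an invented name, and the non-Ulysses branch assumes a load-balanced layout in which the sequence is cut into 2 * cp_size chunks and rank r takes chunk r and chunk 2 * cp_size - r - 1. Under that assumption, forcing cp_size=1/cp_rank=0 selects both chunks in order, which is the full sequence.

```python
from typing import List


def rope_positions(
    seq_len: int, cp_size: int, cp_rank: int, ulysses_cp: bool = False
) -> List[int]:
    """Positions whose RoPE frequencies a rank applies (hypothetical model)."""
    if ulysses_cp:
        # Ulysses keeps the full sequence on every rank, so RoPE must not split.
        cp_size, cp_rank = 1, 0
    num_chunks = 2 * cp_size
    if seq_len % num_chunks != 0:
        raise ValueError("seq_len must be divisible by 2 * cp_size")
    chunk = seq_len // num_chunks
    # Assumed load-balanced layout: one chunk from the front, one from the back.
    picked = (cp_rank, num_chunks - cp_rank - 1)
    return [p for c in picked for p in range(c * chunk, (c + 1) * chunk)]


split_rank0 = rope_positions(8, cp_size=2, cp_rank=0)
ulysses_rank1 = rope_positions(8, cp_size=2, cp_rank=1, ulysses_cp=True)
unsplit = rope_positions(8, cp_size=1, cp_rank=0)
```

In this model ulysses_rank1 and unsplit are the same list, while split_rank0 covers only half of the positions.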
Compatibility:
- cp_comm_type defaulting behavior and non-Ulysses CP paths are
  untouched: SGLangCoreAttention resolves to DotProductAttention
  when not in Ulysses mode, and RoPE defaults ulysses_cp=False, so
  existing cp_size/cp_rank splitting is preserved.
- Flash Attention 3's flash_attn_varlen_func is only imported/used
  on the Ulysses path; non-Ulysses runs do not require FA3.
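Kwarg compatibility via inspect.signature, mentioned under Changes, usually means passing only the keyword arguments that the installed kernel accepts. The sketch below shows that general pattern; it does not reproduce the PR's code, and old_kernel and new_kernel are hypothetical stand-ins whose parameters are invented for illustration and are not the real flash_attn_varlen_func signature.

```python
import inspect
from typing import Any, Callable, Dict


def filter_supported_kwargs(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that fn's signature accepts."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)  # fn takes **kwargs, so everything is accepted
    return {name: value for name, value in kwargs.items() if name in params}


# Hypothetical stand-ins for two versions of an attention kernel.
def old_kernel(q, k, v, causal=False):
    return ("old", causal)


def new_kernel(q, k, v, causal=False, softmax_scale=None):
    return ("new", causal, softmax_scale)


wanted = {"causal": True, "softmax_scale": 0.125}
old_result = old_kernel(1, 2, 3, **filter_supported_kwargs(old_kernel, wanted))
new_result = new_kernel(1, 2, 3, **filter_supported_kwargs(new_kernel, wanted))
```

The same call site then works against either version: the older stand-in silently drops softmax_scale, and the newer one receives it.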
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Implements PR 7A, the Megatron-side half of Ulysses (cp_comm_type="a2a") context-parallel support for the SGLang backend, pairing with Miles PR 11A. Stacked on PR 7 (#26).
Non-Ulysses paths are unchanged: when cp_comm_type is not "a2a" (or cp_size <= 1), SGLangCoreAttention falls back to DotProductAttention and RoPE continues to use the existing cp_size/cp_rank from the CP group.
Contract satisfied
Contract H (Ulysses CP-aware training data, logprob, and loss) on the Megatron scoring path:
- [...] SGLangFlashAttention) that mirrors SGLang's inference layout, so scoring and rollout agree when Miles routes true-on-policy through this backend.
Files changed
- megatron/core/extensions/sglang.py
  - SGLangFlashAttention: FA3 varlen attention for Ulysses. Slices heads across the CP group on input, calls flash_attn_varlen_func, and all-gathers heads on output. Handles packed (thd) and unpacked inputs, BF16 cast, GQA repeat-interleave, and FA3 kwarg compatibility via inspect.signature.
  - SGLangCoreAttention: thin dispatcher that picks SGLangFlashAttention iff config.context_parallel_size > 1 and cp_comm_type == "a2a", otherwise DotProductAttention. Preserves the current_max_attn_logits plumbing that Megatron expects.
  - SGLangSpecProvider.core_attention(): now returns SGLangCoreAttention. The dispatcher defaults back to DotProductAttention outside Ulysses, so this is a no-op for non-Ulysses runs.
- megatron/core/models/common/embeddings/rope_utils.py
  - _apply_rotary_pos_emb_thd and apply_rotary_pos_emb take a new ulysses_cp: bool flag (defaults to False).
  - With ulysses_cp=True, RoPE uses cp_size=1/cp_rank=0, applying freqs to the full unsplit sequence, which matches SGLang's sequence layout.
- megatron/core/transformer/attention.py
  - _is_ulysses_cp(config) helper reads cp_comm_type from config, handling both the str and list[str] forms and the cp_size <= 1 cases.
  - Threads ulysses_cp into both the Q and K apply_rotary_pos_emb call sites in the split_qkv path.
Tests
- tests/unit_tests/extension/test_sglang_ulysses_cp.py
  - test_is_ulysses_cp_detects_a2a_mode: covers cp_size=1 + a2a (false), cp_size=2 + p2p (false), cp_size=2 + a2a (true), and the list[str] form ["a2a", "p2p"] (true).
  - test_apply_rotary_pos_emb_ulysses_matches_unsplit_sequence: Ulysses RoPE output is identical to an unsplit (cp_size=1) RoPE pass.
  - test_sglang_core_attention_dispatches_ulysses_backend: SGLangCoreAttention picks the Ulysses impl under cp_size=2 + a2a.
  - test_sglang_core_attention_falls_back_outside_ulysses: fallback to DotProductAttention under cp_size=2 + p2p.
Compatibility
- Non-Ulysses paths (cp_comm_type != "a2a" or cp_size <= 1): SGLangCoreAttention routes to DotProductAttention, and RoPE defaults to ulysses_cp=False (preserving existing cp_size/cp_rank splitting). No behavior change.
- [...] SGLangSpecProvider is not selected, so none of these paths execute.
- flash_attn_varlen_func is only imported/used inside SGLangFlashAttention.forward. Non-Ulysses runs do not require FA3.
Known performance impact
Pre-commit
- black: pass.
- isort: pass (PR 7A files only).
- pylint: the file inherits preexisting missing-function-docstring warnings from PR 6's SGLangSpecProvider; no new line-too-long or other regressions are introduced by this PR. Rating moved from 9.85 → 9.89.
Dependencies
- [...] miles-main with the PR diff containing PR 6 + PR 7 + PR 7A commits; reviewers should focus on the latest commit.
Test plan
- pytest tests/unit_tests/extension/test_sglang_ulysses_cp.py -v on CPU (no GPU required for the three included tests)
- tests/unit_tests/extension/test_sglang_extension.py passes unchanged
- [...] miles_migration.md)
🤖 Generated with Claude Code