[alignment-megatron] PR7A: SGLang Ulysses CP attention and RoPE #27

Draft
maocheng23 wants to merge 3 commits into radixark:miles-main from maocheng23:maocheng/top-pr07a-sglang-ulysses-cp

Summary

Implements PR 7A — the Megatron-side half of Ulysses (cp_comm_type="a2a") context-parallel support for the SGLang backend, pairing with Miles PR 11A. Stacked on PR 7 (#26).

Non-Ulysses paths are unchanged: when cp_comm_type is not "a2a" (or cp_size <= 1), SGLangCoreAttention falls back to DotProductAttention and RoPE continues to use the existing cp_size/cp_rank from the CP group.

Contract satisfied

Contract H (Ulysses CP-aware training data, logprob, and loss) on the Megatron scoring path:

  • Under Ulysses, the full sequence is kept on every rank and sharding happens only over attention heads (a2a over the head axis).
  • SGLang-backend attention now has a Ulysses implementation (SGLangFlashAttention) that mirrors SGLang's inference layout, so scoring and rollout agree when Miles routes true-on-policy through this backend.
  • RoPE is applied to the full unsplit sequence under Ulysses so position encoding matches the rollout side.
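The head-sharding idea in the first bullet can be illustrated with a single-process simulation. This is a minimal sketch, not the PR's implementation: it uses plain Python lists instead of tensors, assumes each rank owns a contiguous slice of heads, and stands in for the all-gather with a list concatenation. The helper names (`causal_attention_head`, `head_sharded_attention`) are hypothetical.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention_head(q, k, v):
    """Causal attention for one head; q, k, v are [seq][dim] lists."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for t, q_t in enumerate(q):
        # Position t attends to positions 0..t only.
        scores = [scale * sum(a * b for a, b in zip(q_t, k[s])) for s in range(t + 1)]
        probs = softmax(scores)
        out.append([sum(p * v[s][d] for s, p in enumerate(probs)) for d in range(len(v[0]))])
    return out


def full_attention(q, k, v):
    """q, k, v are [head][seq][dim]; every head is computed locally."""
    return [causal_attention_head(qh, kh, vh) for qh, kh, vh in zip(q, k, v)]


def head_sharded_attention(q, k, v, cp_size):
    """Simulate cp_size ranks, each computing its heads over the FULL sequence."""
    num_heads = len(q)
    assert num_heads % cp_size == 0, "heads must divide evenly across ranks"
    per_rank = num_heads // cp_size
    gathered = []
    for rank in range(cp_size):
        lo, hi = rank * per_rank, (rank + 1) * per_rank  # assumed contiguous head slice
        local = full_attention(q[lo:hi], k[lo:hi], v[lo:hi])
        gathered.extend(local)  # stands in for the all-gather over heads
    return gathered


rng = random.Random(0)


def rand_tensor(heads, seq, dim):
    return [[[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(seq)] for _ in range(heads)]


q, k, v = (rand_tensor(4, 5, 3) for _ in range(3))
sharded = head_sharded_attention(q, k, v, cp_size=2)
reference = full_attention(q, k, v)
```

Because heads are independent in attention and no rank ever sees a partial sequence, the sharded result equals the unsharded one exactly in this simulation.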

Files changed

megatron/core/extensions/sglang.py

  • SGLangFlashAttention: FA3 varlen attention for Ulysses. Slices heads across the CP group on input, calls flash_attn_varlen_func, all-gathers heads on output. Handles packed (thd) and unpacked inputs, BF16 cast, GQA repeat-interleave, and FA3 kwarg-compat via inspect.signature.
  • SGLangCoreAttention: thin dispatcher — picks SGLangFlashAttention iff config.context_parallel_size > 1 and cp_comm_type == "a2a", otherwise DotProductAttention. Preserves current_max_attn_logits plumbing that Megatron expects.
  • SGLangSpecProvider.core_attention(): now returns SGLangCoreAttention. The dispatcher falls back to DotProductAttention outside Ulysses, so this is a no-op for non-Ulysses runs.
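The kwarg-compat step mentioned for SGLangFlashAttention can be sketched as follows. This is an illustration of the general `inspect.signature` technique only: `filter_supported_kwargs`, the two stand-in kernel functions, and the `new_option` argument are all hypothetical and do not reflect the real flash_attn_varlen_func signature.

```python
import inspect


def filter_supported_kwargs(func, kwargs):
    """Return only the keyword arguments that func can accept."""
    params = inspect.signature(func).parameters
    # A **kwargs parameter means the callee accepts anything.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {name: value for name, value in kwargs.items() if name in params}


# Hypothetical stand-ins for two kernel versions with different signatures.
def varlen_func_old(q, k, v, causal=False):
    return ("old", causal)


def varlen_func_new(q, k, v, causal=False, new_option=None):
    return ("new", causal, new_option)


requested = {"causal": True, "new_option": 7}
old_result = varlen_func_old(1, 2, 3, **filter_supported_kwargs(varlen_func_old, requested))
new_result = varlen_func_new(1, 2, 3, **filter_supported_kwargs(varlen_func_new, requested))
```

Filtering at the call site lets one code path work against several kernel releases without version checks, at the cost of silently dropping arguments an older kernel does not understand.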

megatron/core/models/common/embeddings/rope_utils.py

  • _apply_rotary_pos_emb_thd and apply_rotary_pos_emb take a new ulysses_cp: bool flag (defaults False).
  • When ulysses_cp=True, RoPE uses cp_size=1 / cp_rank=0, applying freqs to the full unsplit sequence — matching SGLang's sequence layout.
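The effect of forcing cp_size=1 / cp_rank=0 can be sketched by looking at which positions a rank would rotate. The sketch below assumes that standard CP uses a load-balanced split into 2 * cp_size chunks, with each rank taking one chunk from the front and its mirror from the back; that layout is an assumption made for illustration and is not stated in this PR. Both function names are hypothetical.

```python
def positions_on_rank(seq_len, cp_size, cp_rank):
    """Positions handled by one rank under an assumed load-balanced CP split."""
    num_chunks = 2 * cp_size
    assert seq_len % num_chunks == 0, "sequence must divide into 2 * cp_size chunks"
    chunk = seq_len // num_chunks
    first = cp_rank                     # chunk taken from the front
    second = num_chunks - 1 - cp_rank   # mirrored chunk taken from the back
    positions = list(range(first * chunk, (first + 1) * chunk))
    positions += list(range(second * chunk, (second + 1) * chunk))
    return positions


def rope_positions(seq_len, cp_size, cp_rank, ulysses_cp=False):
    """With ulysses_cp=True, ignore the CP group and use the full sequence."""
    if ulysses_cp:
        cp_size, cp_rank = 1, 0
    return positions_on_rank(seq_len, cp_size, cp_rank)


standard_rank0 = rope_positions(8, cp_size=2, cp_rank=0)
standard_rank1 = rope_positions(8, cp_size=2, cp_rank=1)
ulysses_rank1 = rope_positions(8, cp_size=2, cp_rank=1, ulysses_cp=True)
```

Under this assumed layout, overriding the CP coordinates makes the split an identity, so every rank applies frequencies for positions 0 through seq_len - 1.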

megatron/core/transformer/attention.py

  • _is_ulysses_cp(config) helper reads cp_comm_type from config, handling both the str and list[str] forms and the cp_size <= 1 case.
  • Threads ulysses_cp into both Q and K apply_rotary_pos_emb call sites in the split_qkv path.
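A detection helper along these lines might look like the sketch below. It is not the PR's code: the function name `is_ulysses_cp` is hypothetical, and treating a list as Ulysses when any entry is "a2a" is an assumption. The PR only states that the list ["a2a", "p2p"] is detected as Ulysses.

```python
from types import SimpleNamespace


def is_ulysses_cp(config):
    """Return True when the config describes Ulysses (a2a) context parallelism."""
    cp_size = getattr(config, "context_parallel_size", 1) or 1
    if cp_size <= 1:
        return False  # no context parallelism, so nothing to detect
    comm = getattr(config, "cp_comm_type", None)
    if isinstance(comm, str):
        return comm == "a2a"
    if isinstance(comm, (list, tuple)):
        return "a2a" in comm  # assumption: any a2a entry counts
    return False


# The four cases listed for test_is_ulysses_cp_detects_a2a_mode.
cases = [
    (SimpleNamespace(context_parallel_size=1, cp_comm_type="a2a"), False),
    (SimpleNamespace(context_parallel_size=2, cp_comm_type="p2p"), False),
    (SimpleNamespace(context_parallel_size=2, cp_comm_type="a2a"), True),
    (SimpleNamespace(context_parallel_size=2, cp_comm_type=["a2a", "p2p"]), True),
]
results = [is_ulysses_cp(cfg) for cfg, _ in cases]
```

The same boolean can drive both decisions described in this PR: which attention backend the dispatcher picks, and whether RoPE receives ulysses_cp=True.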

Tests

tests/unit_tests/extension/test_sglang_ulysses_cp.py

  • test_is_ulysses_cp_detects_a2a_mode: covers cp_size=1 + a2a (false), cp_size=2 + p2p (false), cp_size=2 + a2a (true), and the list[str] form ["a2a", "p2p"] (true).
  • test_apply_rotary_pos_emb_ulysses_matches_unsplit_sequence: Ulysses RoPE output is identical to an unsplit (cp_size=1) RoPE pass.
  • test_sglang_core_attention_dispatches_ulysses_backend: SGLangCoreAttention picks the Ulysses impl under cp_size=2 + a2a.
  • test_sglang_core_attention_falls_back_outside_ulysses: fallback to DotProductAttention under cp_size=2 + p2p.

Compatibility

  • Non-Ulysses (cp_comm_type != "a2a" or cp_size <= 1): SGLangCoreAttention routes to DotProductAttention, and RoPE defaults to ulysses_cp=False (preserving existing cp_size/cp_rank splitting). No behavior change.
  • use_sglang off: SGLangSpecProvider is not selected, so none of these paths execute.
  • FA3 dependency: flash_attn_varlen_func is only imported/used inside SGLangFlashAttention.forward. Non-Ulysses runs do not require FA3.
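The deferred-import pattern behind the FA3 point can be sketched as below. This is a generic illustration, not the PR's code: `LazyKernelAttention`, `load_kernel`, and the `module_name` / `attr_name` parameters are hypothetical, and the usage lines substitute `math.sqrt` for the real kernel so the sketch runs without FA3 installed.

```python
import importlib


def load_kernel(module_name, attr_name):
    """Import a kernel on first use and fail with a clear message if it is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name} is required for the Ulysses attention path"
        ) from exc
    return getattr(module, attr_name)


class LazyKernelAttention:
    """Toy module that touches the optional dependency only on the Ulysses path."""

    def __init__(self, use_ulysses, module_name, attr_name):
        self.use_ulysses = use_ulysses
        self.module_name = module_name
        self.attr_name = attr_name
        self._kernel = None  # resolved lazily inside forward

    def forward(self, x):
        if not self.use_ulysses:
            return x  # stand-in for the fallback path; no import happens
        if self._kernel is None:
            self._kernel = load_kernel(self.module_name, self.attr_name)
        return self._kernel(x)


# The fallback path works even though the named module does not exist.
fallback = LazyKernelAttention(False, "no_such_fa3_module", "kernel").forward(3)
# math.sqrt stands in for the real kernel on the Ulysses path.
ulysses = LazyKernelAttention(True, "math", "sqrt").forward(9.0)
```

Resolving the import inside forward keeps module import time free of the optional dependency, which is what lets non-Ulysses runs proceed without it.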

Known performance impact

  • CP=1 / standard CP / use_sglang off: none (dispatcher resolves to existing paths; RoPE defaults preserve existing cp math).
  • Ulysses CP: adds an all-gather on attention output and a head-slice on input. Expected — this is how Ulysses is implemented in SGLang, and is what makes scoring numerically match rollout.

Pre-commit

  • black: pass.
  • isort: pass (PR 7A files only).
  • pylint: the file inherits preexisting missing-function-docstring warnings from PR 6's SGLangSpecProvider; no new line-too-long or other regressions are introduced by this PR. Rating moved from 9.85 → 9.89.

Dependencies

Test plan

  • pytest tests/unit_tests/extension/test_sglang_ulysses_cp.py -v on CPU (no GPU required for the four included tests)
  • Non-Ulysses regression: existing tests/unit_tests/extension/test_sglang_extension.py passes unchanged
  • GPU E2E: dense Qwen3 Ulysses CP exact-zero with Miles PR 11A once Milestone 1 E2E is ready (tracked under PR 17 in miles_migration.md)

🤖 Generated with Claude Code

maocheng23 and others added 3 commits April 21, 2026 15:52
Add a clean Megatron backend that calls SGLang-compatible math under a flag:
- sglang.py: SGLangLinear, SGLangRMSNorm, SGLangFlashAttention and related modules
- matmul_tp_inv.py: TP-invariant matmul dispatch for Megatron layers
- transformer_config.py: use_sglang config flag
- arguments.py: --use-sglang CLI arg
- layers.py: conditional SGLang backend selection in TP layers
- gpt_layer_specs.py: SGLang-compatible layer spec builder
- test_sglang_extension.py: import, config, and default-path-unchanged tests

Default training path remains unchanged when use_sglang is off.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Match SGLang's TP reduction order and full-vocab logprob contract:
- mappings.py: tree_all_reduce_sum for deterministic TP reduction
- layers.py: conditional tree allreduce in RowParallelLinear
- gpt_model.py: full-vocab logprob gather/truncate/log-softmax
- transformer_config.py: true_on_policy_logits config
- test_tree_all_reduce.py: TP tree-allreduce tests
- test_true_on_policy_logits.py: full-vocab gather/truncate tests

Default NCCL allreduce path unchanged when flags are off.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Implements PR 7A, stacked on PR 7. Adds the Megatron-side half of
Ulysses (cp_comm_type=a2a) context-parallel support for the SGLang
backend, matching the Miles-side Ulysses CP contract (miles PR 11A).

Non-Ulysses paths are unchanged: when cp_comm_type is not "a2a" (or
cp_size <= 1), SGLangCoreAttention falls back to DotProductAttention
and RoPE continues to use the existing cp_size/cp_rank from the CP
group.

Changes:
- megatron/core/extensions/sglang.py
  - SGLangFlashAttention: FA3 varlen attention that slices heads
    across the CP group on input and all-gathers heads on output
    (Ulysses a2a over the head axis). Handles both packed (thd) and
    unpacked inputs, BF16 cast, GQA repeat, and FA3 kwarg
    compatibility via inspect.signature.
  - SGLangCoreAttention: thin dispatcher that picks
    SGLangFlashAttention when config.context_parallel_size > 1 and
    cp_comm_type == "a2a", otherwise DotProductAttention. Preserves
    the current_max_attn_logits plumbing expected by Megatron.
  - SGLangSpecProvider.core_attention(): now returns
    SGLangCoreAttention.
- megatron/core/models/common/embeddings/rope_utils.py
  - _apply_rotary_pos_emb_thd and apply_rotary_pos_emb take a new
    ulysses_cp flag. When set, RoPE uses cp_size=1/cp_rank=0 so the
    freqs tensor is applied to the full unsplit sequence (Ulysses
    keeps the full sequence on each rank; sharding happens inside
    attention only).
- megatron/core/transformer/attention.py
  - _is_ulysses_cp helper that reads cp_comm_type from config
    (handles both str and list[str] forms) and threads ulysses_cp
    into the RoPE call sites for both Q and K.

Tests:
- tests/unit_tests/extension/test_sglang_ulysses_cp.py
  - _is_ulysses_cp detects a2a, including the list form, and rejects
    cp_size==1 or non-a2a cp_comm_type.
  - apply_rotary_pos_emb with ulysses_cp=True produces the same
    output as an unsplit (cp_size=1) RoPE pass.
  - SGLangCoreAttention dispatches to SGLangFlashAttention under
    Ulysses and falls back to DotProductAttention otherwise.

Compatibility:
- cp_comm_type defaulting behavior and non-Ulysses CP paths are
  untouched: SGLangCoreAttention resolves to DotProductAttention
  when not in Ulysses mode, and RoPE defaults ulysses_cp=False, so
  existing cp_size/cp_rank splitting is preserved.
- Flash Attention 3's flash_attn_varlen_func is only imported/used
  on the Ulysses path; non-Ulysses runs do not require FA3.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>