Add Qwen3 AutoParallel model and examples#482
Open
AlbedoWang wants to merge 1 commit into
Open
Conversation
6d0e1e5 to
15da60f
Compare
There was a problem hiding this comment.
Pull request overview
Adds a Qwen3 reference implementation (dense + MoE) into the repo’s _testing/models suite to exercise AutoParallel’s tracing/solver on modern transformer + MoE patterns (including a local_map-wrapped expert region), along with runnable examples and unit tests. Also includes a small compatibility tweak in the existing DeepSeek-V3 model config plumbing for TorchTitan configs that omit use_grouped_mm.
Changes:
- Add
autoparallel/_testing/models/qwen3.py: Qwen3 Transformer with QK-norm attention, RoPE via precomputed cos/sin cache, optional weight tying, and MoE expert dispatch wrapped inlocal_map. - Add end-to-end example scripts for fake-PG smoke runs and real distributed sanity checks (dense + MoE), plus a TorchTitan dense Qwen3 example.
- Add unit tests covering Qwen3 basics, TorchTitan parity checks (skipped when sibling checkout is absent), and AutoParallel smoke pipelines; add a DSv3 TorchTitan config compatibility test; add a DSv3 config guard for missing
use_grouped_mm.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
autoparallel/_testing/models/qwen3.py |
New Qwen3 reference model (dense + MoE) designed to be traceable/optimizable by AutoParallel, including local_map expert region. |
autoparallel/_testing/models/dsv3.py |
Makes MoE config parsing tolerant of TorchTitan configs that don’t define experts.use_grouped_mm. |
examples/example_qwen3.py |
Fake-PG example that traces/optimizes/applies Qwen3 (dense or MoE) and optionally runs fwd/bwd. |
examples/example_sanity_check_qwen3.py |
Real-GPU distributed sanity training loop for Qwen3 8B under AutoParallel. |
examples/example_sanity_check_qwen3_moe.py |
Real-GPU distributed sanity training loop for Qwen3 MoE under AutoParallel (EP mesh) with chunked vocab-parallel loss. |
examples/example_torchtitan_qwen3_dense.py |
Runs TorchTitan’s dense Qwen3 through AutoParallel placement on real GPUs. |
tests/test_qwen3.py |
Unit tests for Qwen3 forward shape, RoPE parity, debug args parity, and AutoParallel smoke tests (dense + MoE). |
tests/test_dsv3_torchtitan_config.py |
Verifies DSv3 accepts a TorchTitan grouped-experts config (skips without sibling TorchTitan checkout). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+704
to
+712
| # Annotate as plain tensors: parameters() yields Parameter, but | ||
| # _to_activation_device returns Tensor, and we reassign in place. | ||
| experts_w1: torch.Tensor | ||
| experts_w2: torch.Tensor | ||
| experts_w3: torch.Tensor | ||
| experts_w1, experts_w2, experts_w3 = self.experts.parameters() | ||
| experts_w1 = _to_activation_device(experts_w1, x) | ||
| experts_w2 = _to_activation_device(experts_w2, x) | ||
| experts_w3 = _to_activation_device(experts_w3, x) |
Comment on lines
+31
to
+33
| _add_sibling_torchtitan_to_path() | ||
|
|
||
| from torchtitan.models.qwen3 import Qwen3Model, qwen3_configs # noqa: E402 |
Comment on lines
+307
to
+310
| torch.manual_seed(args.seed) | ||
| model_args = make_model_args(args.flavor, args.seq_len) | ||
| if args.seq_len is None: | ||
| args.seq_len = model_args.max_seq_len |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a Qwen3 reference model (dense + MoE) under
autoparallel/_testing/models/, runnable examples, and unit tests. The MoE variant wraps expert dispatch inlocal_mapso the solver treats the expert computation as a single sharded node.What's added
autoparallel/_testing/models/qwen3.py— Qwen3Transformer(QK-norm attention, RoPE via a precomputed cos/sin buffer, optional weight tying, MoE block withlocal_map-wrapped experts),Qwen3ModelArgs, debug configs, andqwen3_args_from_torchtitan_configfor parity with torchtitan.examples/example_qwen3.py,example_sanity_check_qwen3.py,example_sanity_check_qwen3_moe.py,example_torchtitan_qwen3_dense.py— end-to-end AutoParallel runs (trace → optimize → apply → forward/backward).tests/test_qwen3.py,tests/test_dsv3_torchtitan_config.py.dsv3.py: one-line guard (getattr(..., "use_grouped_mm", True)) so the shared MoE config plumbing tolerates configs without that field.How to Test
python -m pytest tests/(uses fake PG, no GPU needed)python examples/example_qwen3.py(uses fake PG, single GPU for meta-device ops)Success Criteria
example_qwen3.pyruns successfully end-to-end (tracing, optimization, apply sharding, forward + backward)local_mapwraps expert dispatch correctly and the solver handles thelocal_mapnode (covered bytest_qwen3_moe_auto_parallel_smoke)python -m pytest tests/) — qwen3/dsv3 suites: 7 passed, 4 skipped (the skips depend on a torchtitan sibling checkout)Test coverage (
test_qwen3.py)Forward shape, QK-norm effect, weight-tying survival through
init_weights, dense/MoE shape parity with torchtitan debug args, torchtitan config parsing (skips without sibling checkout), RoPE cos/sin parity, dense AutoParallel pipeline smoke, and MoEauto_parallelsmoke (param count/shapes + forward/backward).