[torchtitan] fix: pass attn_backend to model_registry() and sync with engine_config.attn_type#21
Draft
Conversation
…attn_type

Bug: TorchtitanEngineConfig.attn_type defaults to 'flex', but
model_registry() defaults to 'sdpa'. verl never passed attn_backend to
model_registry(), so every torchtitan user silently ran sdpa even when
their config said flex or varlen.

Root causes:
1. model_registry(flavor) called without the attn_backend parameter →
   always uses the default 'sdpa'.
2. The fallback override used model_spec.model.layer (singular), but
   torchtitan configs use .layers (plural list), so the guard always
   failed and the override never executed.
3. A second read site (line 599) also used the wrong attribute path:
   self.trainer.model_config.layer.attention.attn_backend.

Fix:
- Pass attn_backend=engine_config.attn_type to model_registry().
- Read attn_type from self.engine_config.attn_type (the authoritative
  source).
- Remove the broken post-hoc override entirely.

Affected models: llama3, qwen3, deepseek_v3, llama4 (all models)
Affected hardware: CUDA, NPU, XPU (all platforms)
Tested with: torchtitan 0.2.2 (HEAD), NVIDIA A100 80GB

Evidence:
- evidence/evidence_attn_backend_bug.log: reproduction proof
- evidence/evidence_test_results.log: 17/17 tests pass
- tests/unit/torchtitan/test_attn_backend_sync.py: unit tests
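A minimal sketch of the fix, assuming the current model_registry(flavor, attn_backend=...) signature; flavor here stands in for the actual call-site variable:

    # Before: attn_backend was never forwarded, so model_registry()
    # silently fell back to its default, 'sdpa'.
    model_spec = model_registry(flavor)

    # After: engine_config.attn_type ('flex', 'varlen', or 'sdpa') is
    # the authoritative source and is passed through explicitly.
    model_spec = model_registry(flavor, attn_backend=engine_config.attn_type)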
…oss torchtitan snapshots
torchtitan's model_registry() API changed between commits:
- Feb 2026 (9810191): model_registry(flavor) — no attn_backend param,
backend baked per flavor (e.g. 'debugmodel_flex_attn')
- Apr 2026 (7cec166+, current): model_registry(flavor, attn_backend='sdpa')
Use inspect.signature to detect which API is present (sketched after
this message), so both snapshots work correctly. Emit a clear warning
on the legacy path so users know their attn_type setting could not be
applied.
Also fix the second bug site: read attn_type from self.engine_config.attn_type
instead of the broken self.trainer.model_config.layer.attention.attn_backend
path (singular 'layer' — always failed; correct path is 'layers', plural).
Tests: 18/18 pass, including a mock test for the legacy fallback path.
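A sketch of that detection, assuming only that model_registry is a callable; the helper name resolve_model_spec is hypothetical, and the warning-only fallback matches this commit (it is hardened into an error later in the review):

    import inspect
    import logging

    logger = logging.getLogger(__name__)

    def resolve_model_spec(model_registry, flavor: str, attn_type: str):
        """Call model_registry() compatibly across both torchtitan snapshots."""
        if "attn_backend" in inspect.signature(model_registry).parameters:
            # Apr 2026+ API: forward the configured backend explicitly.
            return model_registry(flavor, attn_backend=attn_type)
        # Feb 2026 API: the backend is baked into the flavor name
        # (e.g. 'debugmodel_flex_attn'), so attn_type cannot be applied.
        logger.warning(
            "torchtitan's model_registry() has no 'attn_backend' parameter; "
            "attn_type=%r could not be applied",
            attn_type,
        )
        return model_registry(flavor)

Inspecting the signature is used rather than wrapping the call in try/except TypeError, so a TypeError raised inside model_registry() itself is not mistaken for a signature mismatch.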
Four reviewer findings addressed:
HIGH - Test collection crash when no torchtitan models are discovered:
the fixture's ids= lambda assumed tuple params; with empty params it
received a NotSetType sentinel, causing a TypeError before any test ran.
Fix: guard the lambda with an isinstance check so the empty-param case
is safe (sketched below).
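A hypothetical reconstruction of that guard, following the finding above; the actual fixture and parameter names in tests/unit/torchtitan/test_attn_backend_sync.py may differ:

    import pytest

    # May be empty when no torchtitan models are importable in the
    # current environment.
    DISCOVERED_MODELS: list = []

    @pytest.fixture(
        params=DISCOVERED_MODELS,
        # With empty params, the ids callable received a NotSetType
        # sentinel rather than a (model, flavor) tuple; indexing it
        # raised TypeError during collection. The isinstance guard
        # keeps the empty-param case safe.
        ids=lambda p: f"{p[0]}-{p[1]}" if isinstance(p, tuple) else str(p),
    )
    def model_case(request):
        return request.param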
MEDIUM - Legacy fallback continued with a mismatched model+mask backend:
on older torchtitan (no attn_backend param), the code warned and
continued. But get_attention_masks() only supports flex/varlen; with a
model built as sdpa, this produces silent training corruption.
Fix: raise RuntimeError for flex/varlen on legacy torchtitan. sdpa with
the legacy API is fine (model and mask both take the sdpa/causal path).
LOW - Wrong flavor suffix in error message:
Suggested '{flavor}_flex' but legacy flavors use '_flex_attn' / '_varlen_attn'.
Fix: corrected to '{flavor}_flex_attn' and '{flavor}_varlen_attn'.
LOW - Unused 'patch' import in test file:
Fix: removed.
Also: moved 'import inspect' to top-level module imports instead of
inside the function body.
Two changes:

1. Fix the 'sdpa is safe on legacy' edge case: use_remove_padding
   defaults to True independently of attn_type. When
   use_remove_padding=True, get_attention_masks() is always called, and
   it only supports flex/varlen; sdpa raises TypeError there. So sdpa on
   legacy torchtitan is also unsafe. Remove the conditional 'flex/varlen
   only' guard and raise RuntimeError for any attn_type when
   model_registry() lacks the attn_backend parameter (see the sketch
   below).

2. Drop the evidence/ run-artifact logs from the PR diff. These are not
   appropriate for an upstream source PR.
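Under the same assumptions as the earlier sketch, the hypothetical helper's final form:

    import inspect

    def resolve_model_spec(model_registry, flavor: str, attn_type: str):
        """Final form after review: legacy torchtitan is rejected outright,
        because use_remove_padding=True always routes through
        get_attention_masks(), which only supports flex/varlen."""
        if "attn_backend" not in inspect.signature(model_registry).parameters:
            raise RuntimeError(
                "Installed torchtitan's model_registry() lacks the "
                f"'attn_backend' parameter, so attn_type={attn_type!r} "
                "cannot be applied safely. Upgrade torchtitan, or select a "
                f"flavor with the backend baked in, e.g. '{flavor}_flex_attn' "
                f"or '{flavor}_varlen_attn'."
            )
        return model_registry(flavor, attn_backend=attn_type)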
What does this PR do?

Problem

When using TorchTitanEngine with attn_type = "flex" (the default in
TorchtitanEngineConfig), the model was silently built with "sdpa"
attention instead, causing a runtime mismatch and a crash during training.

Two bugs compounded each other:

Bug 1 — model_registry() called without attn_backend

In TorchTitanEngine.__init__, the call was:
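(Reconstructed from the commit message; the exact variable names are assumptions.)

    # attn_backend is never passed, so model_registry() falls back to
    # its default, 'sdpa', regardless of engine_config.attn_type.
    model_spec = model_registry(model_config.flavor)

Design & Code Changes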
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
- Once ready for CI, post in the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR touches the recipe submodule, please also update the reference to the submodule commit via git submodule update --remote or cd recipe && git pull origin main.