
[torchtitan] fix: pass attn_backend to model_registry() and sync with engine_config.attn_type #21

Draft

kahlun wants to merge 4 commits into main from fix/torchtitan-attn-backend-sync

Conversation


@kahlun kahlun commented Apr 28, 2026

What does this PR do?

Problem

When using TorchTitanEngine with attn_type = "flex" (the default in
TorchtitanEngineConfig), the model was silently built with "sdpa"
attention instead, causing a runtime mismatch and crash during training.

Two bugs compounded each other:

Bug 1 — model_registry() called without attn_backend
In TorchTitanEngine.__init__, the call was:

model_spec = model_module.model_registry(torchtitan_flavor)

Since attn_backend was not passed, torchtitan used its own per-model default
("sdpa" for llama3/qwen3; "flex" only for llama4/MoE flavors), completely
ignoring engine_config.attn_type.
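
For contrast, a minimal before/after sketch of the call (names as in the snippet above; the fixed form is described under "Fix" below):

```python
# Before: attn_backend omitted, so torchtitan falls back to its per-model
# default ("sdpa" for llama3/qwen3) and engine_config.attn_type is ignored.
model_spec = model_module.model_registry(torchtitan_flavor)

# After: the engine config is the authoritative source.
model_spec = model_module.model_registry(torchtitan_flavor, attn_backend=attn_type)
```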

Bug 2 — override guard never triggered
The fallback that tried to patch the model spec afterward checked
hasattr(model_spec.model, "layer") (singular), but torchtitan uses "layers"
(plural). The guard was always False, so the override body was dead code.
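
A minimal sketch of the dead path (a simplification; the real override body differed):

```python
# torchtitan model specs expose "layers" (plural), so this hasattr check
# never passes and the override below is unreachable.
if hasattr(model_spec.model, "layer"):
    model_spec.model.layer.attention.attn_backend = attn_type  # dead code
```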

Result: model built with sdpa modules, but get_attention_masks() called
with attn_type="flex", causing a crash or silently wrong behavior.

Fix

- Pass attn_backend=attn_type to model_registry() directly (see the sketch after this list).
- Use inspect.signature to detect whether the installed torchtitan supports
  the attn_backend parameter (added in commit 7cec166, Apr 2026). On older
  snapshots, raise RuntimeError immediately with a clear upgrade message; the
  old path was never safe anyway (get_attention_masks() only supports
  flex/varlen).
- Fix the second call site in prepare_model_inputs to read
  self.engine_config.attn_type (it previously read
  self.trainer.model_config.layer.attention.attn_backend, a path that doesn't
  exist).
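
A minimal sketch of the version guard, assuming only what the list above states (the actual patch in transformer_impl.py may be shaped differently):

```python
import inspect

# Detect whether the installed torchtitan's model_registry() accepts
# attn_backend (added in torchtitan commit 7cec166, Apr 2026).
registry_params = inspect.signature(model_module.model_registry).parameters
if "attn_backend" not in registry_params:
    raise RuntimeError(
        "torchtitan is too old: model_registry() has no attn_backend "
        "parameter. Upgrade to commit 7cec166 or later; the legacy path "
        "cannot honor engine_config.attn_type."
    )
model_spec = model_module.model_registry(torchtitan_flavor, attn_backend=attn_type)
```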
Testing
Added tests/unit/torchtitan/test_attn_backend_sync.py (18 tests):

- Verifies model_registry accepts attn_backend for all registered models
- Verifies attn_type="flex" produces FlexAttention.Config in the model spec
- Verifies the inspect guard raises RuntimeError on legacy torchtitan
- Verifies source-level presence of the inspect guard (regression guard)
All 18 tests pass on A100 with torchtitan commit 7cec166.
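
For illustration, the signature check at the core of these tests looks roughly like this (a sketch; get_model_module is a hypothetical helper, and the real test file parametrizes over all registered models):

```python
import inspect

def test_model_registry_accepts_attn_backend():
    # Hypothetical single-model version of the parametrized check described
    # above; "llama3" stands in for any registered model.
    model_module = get_model_module("llama3")
    params = inspect.signature(model_module.model_registry).parameters
    assert "attn_backend" in params
```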

torchtitan version requirement
Requires torchtitan at commit 7cec166 or later (merged Apr 2026).
If you are on an older snapshot, the engine will raise RuntimeError at
startup with instructions to upgrade.


---

The branch has 4 commits changing exactly 2 files:
- [verl/workers/engine/torchtitan/transformer_impl.py](verl-upstream-main/verl/workers/engine/torchtitan/transformer_impl.py) — the fix
- `tests/unit/torchtitan/test_attn_backend_sync.py` — 18 unit tests

The target base for the upstream PR would be `volcengine/verl:main`.

---


> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `vllm_omni`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```
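
A usage sketch (engine construction details are assumptions; only the attn_type field and the pass-through behavior come from this PR):

```python
# Hedged sketch: after this fix, the configured attn_type flows through to
# torchtitan's model_registry(). Exact constructor arguments are assumptions.
engine_config = TorchtitanEngineConfig(attn_type="flex")
engine = TorchTitanEngine(engine_config)
# Internally the engine now calls:
#   model_module.model_registry(torchtitan_flavor, attn_backend="flex")
```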

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

kahlun added 4 commits April 28, 2026 10:18
…attn_type

Bug: TorchtitanEngineConfig.attn_type defaults to 'flex', but
model_registry() defaults to 'sdpa'. verl never passed attn_backend
to model_registry(), so every torchtitan user silently ran sdpa even
when their config said flex or varlen.

Root causes:
1. model_registry(flavor) called without attn_backend parameter
   → always uses the default 'sdpa'
2. The fallback override used model_spec.model.layer (singular) but
   torchtitan configs use .layers (plural list), so the guard always
   failed and the override never executed
3. A second read site (line 599) also used the wrong attribute path
   self.trainer.model_config.layer.attention.attn_backend

Fix:
- Pass attn_backend=engine_config.attn_type to model_registry()
- Read attn_type from self.engine_config.attn_type (authoritative source)
- Remove the broken post-hoc override entirely

Affected models: llama3, qwen3, deepseek_v3, llama4 (all models)
Affected hardware: CUDA, NPU, XPU (all platforms)
Tested with: torchtitan 0.2.2 (HEAD), NVIDIA A100 80GB

Evidence:
- evidence/evidence_attn_backend_bug.log: reproduction proof
- evidence/evidence_test_results.log: 17/17 tests pass
- tests/unit/torchtitan/test_attn_backend_sync.py: unit tests

…oss torchtitan snapshots

torchtitan's model_registry() API changed between commits:
  - Feb 2026 (9810191): model_registry(flavor) — no attn_backend param,
    backend baked per flavor (e.g. 'debugmodel_flex_attn')
  - Apr 2026 (7cec166+, current): model_registry(flavor, attn_backend='sdpa')

Use inspect.signature to detect which API is present, so both snapshots
work correctly. Emit a clear warning on the legacy path so users know
their attn_type setting could not be applied.

Also fix second bug site: read attn_type from self.engine_config.attn_type
instead of the broken self.trainer.model_config.layer.attention.attn_backend
path (singular 'layer' — always failed; correct path is 'layers', plural).

Tests: 18/18 pass, including a mock test for the legacy fallback path.

Four reviewer findings addressed:

HIGH - Test collection crash when no torchtitan models discovered:
  fixture ids= lambda assumed tuple params; used NotSetType sentinel
  on empty params causing TypeError before any test ran.
  Fix: guard lambda with isinstance check so empty-param case is safe.

MEDIUM - Legacy fallback continued with mismatched model+mask backend:
  On older torchtitan (no attn_backend param), code warned and continued.
  But get_attention_masks() only supports flex/varlen; with a model built
  as sdpa this produces silent training corruption.
  Fix: raise RuntimeError for flex/varlen on legacy torchtitan. sdpa
  with legacy API is fine (model+mask both take the sdpa/causal path).

LOW - Wrong flavor suffix in error message:
  Suggested '{flavor}_flex' but legacy flavors use '_flex_attn' / '_varlen_attn'.
  Fix: corrected to '{flavor}_flex_attn' and '{flavor}_varlen_attn'.

LOW - Unused 'patch' import in test file:
  Fix: removed.

Also: moved 'import inspect' to top-level module imports instead of
inside the function body.

Two changes:

1. Fix 'sdpa is safe on legacy' edge case:
   use_remove_padding defaults to True independently of attn_type.
   When use_remove_padding=True, get_attention_masks() is always called,
   and it only supports flex/varlen — sdpa raises TypeError there.
   So sdpa on legacy torchtitan is also unsafe. Remove the conditional
   'flex/varlen only' guard and raise RuntimeError for any attn_type when
   model_registry() lacks the attn_backend parameter.

2. Drop evidence/ run-artifact logs from the PR diff.
   These are not appropriate for an upstream source PR.