Fix loss_type="chunked_nll" under DeepSpeed ZeRO-3 #5873
    params = [w] if b is None else [w, b]
    if all(p.ds_status == ZeroParamStatus.AVAILABLE for p in params):
        return contextlib.nullcontext()
    return deepspeed.zero.GatheredParameters(params)

How does this work in backward?

No gathered weight has to survive from forward to backward. The chunk matmul runs inside `torch.utils.checkpoint.checkpoint(_chunk, ..., use_reentrant=False)` (trl/trl/trainer/sft_trainer.py, lines 212 to 221 in 70d9081), and `_maybe_gather_lm_head_ctx` wraps the matmul inside `_chunk`. With `use_reentrant=False`, the checkpoint recomputes `_chunk` in backward, so the gather context is re-entered at backward time and the weight is in the right state exactly when the chunk's gradient is computed.
The re-entered context is correct in both states:
- resident (tied embeddings / already engine-gathered) → no-op.
- partitioned → re-gathers.
  `GatheredParameters` releases on exit, so a weight gathered in a forward chunk is `NOT_AVAILABLE` again by backward.

I confirmed this on Qwen3-0.6B (ZeRO-3, 2 ranks): the gather context fires in both forward and backward, and training runs fine (tied head, so backward hits the `AVAILABLE` no-op).
The partitioned path in isolation (bare `zero.Init`, weight `NOT_AVAILABLE`) is confirmed as well.
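
To make the control flow in this reply concrete, here is a minimal pure-Python sketch. It is a toy model, not DeepSpeed or PyTorch code: `Status`, `ToyParam`, `toy_gather`, and `toy_checkpoint` are hypothetical stand-ins for `ZeroParamStatus`, a ZeRO-3 parameter, `GatheredParameters`, and a non-reentrant checkpoint. It only illustrates that a context entered inside the checkpointed function is entered again when backward recomputes that function.

```python
import contextlib
from enum import Enum


class Status(Enum):
    # Stand-in for the two parameter states discussed in this thread.
    NOT_AVAILABLE = 0
    AVAILABLE = 1


class ToyParam:
    """Hypothetical stand-in for a ZeRO-3 parameter (not a DeepSpeed class)."""

    def __init__(self, full_value, resident=False):
        self.full_value = full_value
        self.value = full_value if resident else None  # None models the empty local shard
        self.ds_status = Status.AVAILABLE if resident else Status.NOT_AVAILABLE
        self.events = []  # records what the gather context did on each entry


@contextlib.contextmanager
def toy_gather(param):
    """Gather on enter, release on exit; no-op if the weight is already resident."""
    if param.ds_status is Status.AVAILABLE:
        param.events.append("no-op")
        yield
        return
    param.value, param.ds_status = param.full_value, Status.AVAILABLE
    param.events.append("gather")
    try:
        yield
    finally:
        # Release on exit, so nothing gathered survives past this block.
        param.value, param.ds_status = None, Status.NOT_AVAILABLE


def chunk(param, x):
    # The gather lives inside the checkpointed function, so every re-run re-enters it.
    with toy_gather(param):
        w = param.value
        return w * x, w  # (output, weight value that backward needs)


def toy_checkpoint(fn, param, x):
    """Toy non-reentrant checkpoint: keep only the inputs, re-run fn in backward."""
    out, _ = fn(param, x)  # forward: keep the output, drop everything else

    def backward(grad_out):
        _, w = fn(param, x)  # recompute, which re-enters the gather context
        return w * grad_out  # gradient w.r.t. x for out = w * x

    return out, backward


partitioned = ToyParam(3.0)
out, backward = toy_checkpoint(chunk, partitioned, 2.0)
status_after_forward = partitioned.ds_status  # weight was released after forward
grad_x = backward(1.0)

tied = ToyParam(3.0, resident=True)
tied_out, tied_backward = toy_checkpoint(chunk, tied, 2.0)
tied_grad_x = tied_backward(1.0)
```

In this toy, `partitioned.events` records a gather on each pass and `tied.events` records a no-op on each pass, matching the two states listed above.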

The chunked-CE path in `SFTTrainer` reads `lm_head.weight` (and `lm_head.bias`) directly to do its own per-chunk matmul, bypassing the DeepSpeed pre-forward hooks that normally allgather ZeRO-3 partitioned params. On non-owning ranks the weight is a 0-element shard and the matmul crashes.

For the traceback, see https://github.com/huggingface/trl/actions/runs/26475700871/job/77960573388?pr=5846#step:9:60

This adds a tiny `_maybe_gather_lm_head_ctx(w, b)` helper that wraps the matmul in `deepspeed.zero.GatheredParameters` when the param is ZeRO-3-partitioned, and is a no-op otherwise (DDP, FSDP2, single-GPU, or when the param is already `AVAILABLE`, e.g. tied embeddings where `embed_tokens` shares the weight).

With `use_reentrant=False` the checkpoint recompute re-enters the context during backward, so the gather happens on both passes.

Also gathers in the `n_valid == 0` fallback branch, which reads `lm_head_weight.float().sum()` for the dummy loss.

Note
Medium Risk
Changes distributed training loss computation for `chunked_nll`; behavior is scoped to ZeRO-3 gather paths and is covered by new multi-GPU tests.

Overview
Fixes `loss_type="chunked_nll"` when training with DeepSpeed ZeRO-3: the chunked CE path reads `lm_head` weight/bias directly, so non-owning ranks could see empty shards and crash on the matmul.

Adds `_maybe_gather_lm_head_ctx`, which wraps those reads in `GatheredParameters` only when ZeRO-3 is on and params are not already `AVAILABLE` (no-op for DDP/FSDP2/single-GPU and tied embeddings). The same gather applies in `_chunk` and in the `n_valid == 0` dummy-loss branch.

Adds distributed test `test_sft_chunked_nll` (DDP, ZeRO-2/3, FSDP2) that runs SFT with `--loss_type chunked_nll`.

Reviewed by Cursor Bugbot for commit 70d9081.
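
For readers skimming the description, the helper's three-way decision can be sketched without DeepSpeed installed. This is an illustration under stated assumptions, not the code in the PR: `maybe_gather_ctx`, `FakeZeroParam`, and `FakeGather` are hypothetical names, the `ZeroParamStatus` enum below is a local stand-in, and treating a missing `ds_status` attribute as "not managed by ZeRO-3" is an assumption about how the non-ZeRO case could be detected.

```python
import contextlib
from enum import Enum


class ZeroParamStatus(Enum):
    # Local stand-in for DeepSpeed's enum; only these two states matter here.
    AVAILABLE = 1
    NOT_AVAILABLE = 2


class FakeZeroParam:
    """Hypothetical stand-in for a parameter managed by ZeRO-3."""

    def __init__(self, status):
        self.ds_status = status


class FakeGather:
    """Hypothetical stand-in for deepspeed.zero.GatheredParameters."""

    def __init__(self, params):
        self.params = params

    def __enter__(self):
        for p in self.params:
            p.ds_status = ZeroParamStatus.AVAILABLE  # full weight visible inside the block
        return self

    def __exit__(self, *exc):
        for p in self.params:
            p.ds_status = ZeroParamStatus.NOT_AVAILABLE  # released again on exit
        return False


def maybe_gather_ctx(w, b=None, gather_cls=FakeGather):
    params = [w] if b is None else [w, b]
    # Assumption: params without `ds_status` are not ZeRO-3 managed (DDP, FSDP2, single GPU).
    if not all(hasattr(p, "ds_status") for p in params):
        return contextlib.nullcontext()
    # Already resident (e.g. tied embeddings): gathering again would be redundant.
    if all(p.ds_status == ZeroParamStatus.AVAILABLE for p in params):
        return contextlib.nullcontext()
    return gather_cls(params)


plain_ctx = maybe_gather_ctx(object())  # non-ZeRO parameter
resident_ctx = maybe_gather_ctx(FakeZeroParam(ZeroParamStatus.AVAILABLE))

w = FakeZeroParam(ZeroParamStatus.NOT_AVAILABLE)
partitioned_ctx = maybe_gather_ctx(w)
with partitioned_ctx:
    status_inside = w.ds_status  # the matmul would run here
status_after = w.ds_status
```

Returning `contextlib.nullcontext()` in the first two cases lets the call site use a single `with` statement regardless of the training backend.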