Skip to content

Fix loss_type="chunked_nll" under DeepSpeed ZeRO-3#5873

Open
qgallouedec wants to merge 2 commits into
mainfrom
fix-chunked-zero3
Open

Fix loss_type="chunked_nll" under DeepSpeed ZeRO-3#5873
qgallouedec wants to merge 2 commits into
mainfrom
fix-chunked-zero3

Conversation

@qgallouedec
Copy link
Copy Markdown
Member

@qgallouedec qgallouedec commented May 27, 2026

The chunked-CE path in SFTTrainer reads lm_head.weight (and lm_head.bias) directly to do its own per-chunk matmul, bypassing the DeepSpeed pre-forward hooks that normally allgather ZeRO-3 partitioned params. On non-owning ranks the weight is a 0-element shard and the matmul crashes.

To see the traceback, see https://github.com/huggingface/trl/actions/runs/26475700871/job/77960573388?pr=5846#step:9:60

This adds a tiny _maybe_gather_lm_head_ctx(w, b) helper that wraps the matmul in deepspeed.zero.GatheredParameters when the param is ZeRO-3-partitioned, and is a no-op otherwise (DDP, FSDP2, single-GPU, or when the param is already AVAILABLE, e.g. tied embeddings where embed_tokens shares the weight).

With use_reentrant=False the checkpoint recompute re-enters the context during backward, so the gather happens on both passes.

Also gathers in the n_valid == 0 fallback branch which reads lm_head_weight.float().sum() for the dummy loss.


Note

Medium Risk
Changes distributed training loss computation for chunked_nll; behavior is scoped to ZeRO-3 gather paths and is covered by new multi-GPU tests.

Overview
Fixes loss_type="chunked_nll" when training with DeepSpeed ZeRO-3: the chunked CE path reads lm_head weight/bias directly, so non-owning ranks could see empty shards and crash on the matmul.

Adds _maybe_gather_lm_head_ctx, which wraps those reads in GatheredParameters only when ZeRO-3 is on and params are not already AVAILABLE (no-op for DDP/FSDP2/single-GPU and tied embeddings). The same gather applies in _chunk and in the n_valid == 0 dummy-loss branch.

Adds distributed test test_sft_chunked_nll (DDP, ZeRO-2/3, FSDP2) that runs SFT with --loss_type chunked_nll.

Reviewed by Cursor Bugbot for commit 70d9081. Bugbot is set up for automated code reviews on this repo. Configure here.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f1be826880

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread trl/trainer/sft_trainer.py Outdated
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown

@cursor cursor Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit f1be826. Configure here.

Comment thread trl/trainer/sft_trainer.py Outdated
params = [w] if b is None else [w, b]
if all(p.ds_status == ZeroParamStatus.AVAILABLE for p in params):
return contextlib.nullcontext()
return deepspeed.zero.GatheredParameters(params)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how does this work in backward ?

Copy link
Copy Markdown
Member Author

@qgallouedec qgallouedec May 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No gathered weight has to survive forward→backward. The chunk matmul runs inside torch.utils.checkpoint.checkpoint(_chunk, ..., use_reentrant=False)

chunk_loss, chunk_correct, chunk_entropy = torch.utils.checkpoint.checkpoint(
_chunk,
h_chunk,
lm_head_weight,
lm_head_bias,
lbl_chunk,
logit_scale,
final_logit_softcapping,
use_reentrant=False,
)

and _maybe_gather_lm_head_ctx wraps the matmul inside _chunk. With use_reentrant=False the checkpoint recomputes _chunk in backward, so the gather context re-enters at backward time and the weight is in the right state exactly when the chunk's gradient is computed.

The re-entered context is correct in both states:

  • resident (tied embeddings / already engine-gathered) → no-op.
  • partitioned → re-gathers. GatheredParameters releases on exit, so a weight gathered in a forward chunk is NOT_AVAILABLE again by backward

I confirmed on Qwen3-0.6B (ZeRO-3, 2 ranks): the gather context fires in both forward and backward and training runs fine (tied head, so backward hits the AVAILABLE no-op)

And the partitioned path in isolation (bare zero.Init, weight NOT_AVAILABLE) is confirmed as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants