Skip to content

[BUG] Significant Performance Difference between DeepSpeed's zero_stage=1 and zero_stage=2 #7697

@liyc-ai

Description

@liyc-ai

Describe the bug
I am using the SFTTrainer from Hugging Face's TRL, and I found that training and evaluation exhibit a significant difference between zero_stage=1 and zero_stage=2, as shown below:

Image Image

The red lines correspond to zero_stage=1, whereas the black lines correspond to zero_stage=2. Can anyone help me? Thanks a lot. By the way, I also tried OpenRLHF, where the same issue exists for both SFT and DPO.

To Reproduce

To reproduce my issue, please first preprocess the dataset via

import os

import datasets
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

# Monkey-patch that disables the free-disk-space check `datasets` performs
# before downloading; useful on filesystems where the check misreports the
# available space.
datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True

output_dir = "data/UltraFeedback"
os.makedirs(output_dir, exist_ok=True)

dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized")

# get train set: keep only pairs where both responses are non-empty
prompts, chosen, rejected = [], [], []
for item in tqdm(dataset["train_prefs"]):
    if item["chosen"][1]["content"] == "" or item["rejected"][1]["content"] == "":
        continue

    prompts.append(item["prompt"])
    chosen.append(item["chosen"][1]["content"])
    rejected.append(item["rejected"][1]["content"])

df = pd.DataFrame(
    {
        "prompt": prompts,
        "chosen": chosen,
        "rejected": rejected,
    }
)
# index=False for consistency with every other to_parquet call in this
# script; the original call silently persisted the RangeIndex as a column.
df.to_parquet(os.path.join(output_dir, "all_train.parquet"), index=False)

# Shuffle once with a fixed seed, then carve the pool into SFT / RM / RL.
ratios = [0.2, 0.4, 0.4]
n_total = len(prompts)

n_sft = int(n_total * ratios[0])
n_rm = int(n_total * ratios[1])
n_rl = n_total - n_sft - n_rm

shuffled = df.sample(frac=1, random_state=42).reset_index(drop=True)
part_sft = shuffled.iloc[:n_sft]
part_rm = shuffled.iloc[n_sft : n_sft + n_rm]
part_rl = shuffled.iloc[n_sft + n_rm :]

print(f"Original: {n_total}")
print(f"SFT: {len(part_sft)} ({len(part_sft)/n_total:.1%})")
print(f"RM: {len(part_rm)} ({len(part_rm)/n_total:.1%})")
print(f"RL: {len(part_rl)} ({len(part_rl)/n_total:.1%})")

# SFT split: keep the raw pairs, then store a prompt/completion version.
part_sft.to_parquet(os.path.join(output_dir, "raw_sft.parquet"), index=False)
part_sft = part_sft.drop(columns=["rejected"]).rename(columns={"chosen": "completion"})
part_sft.to_parquet(os.path.join(output_dir, "sft.parquet"), index=False)

# RM split: keep chosen/rejected pairs as-is.
part_rm.to_parquet(os.path.join(output_dir, "rm.parquet"), index=False)

# RL split: the remainder of the shuffled pool.
part_rl.to_parquet(os.path.join(output_dir, "rl.parquet"), index=False)

# get test set, applying the same empty-response filter as the train split
prompts, chosen, rejected = [], [], []
for item in tqdm(dataset["test_prefs"]):
    chosen_text = item["chosen"][1]["content"]
    rejected_text = item["rejected"][1]["content"]
    if chosen_text == "" or rejected_text == "":
        continue

    prompts.append(item["prompt"])
    chosen.append(chosen_text)
    rejected.append(rejected_text)

test_df = pd.DataFrame(
    {"prompt": prompts, "chosen": chosen, "rejected": rejected}
)

# sft: prompt/completion pairs derived from the chosen responses
test_sft_df = test_df.drop(columns=["rejected"]).rename(
    columns={"chosen": "completion"}
)
test_sft_df.to_parquet(os.path.join(output_dir, "test_sft.parquet"), index=False)

# rm: full pairwise frame
test_df.to_parquet(os.path.join(output_dir, "test_rm.parquet"), index=False)

# rl: same frame reused for the policy stage
test_df.to_parquet(os.path.join(output_dir, "test_rl.parquet"), index=False)

Then, start training via bash train.sh. The content of train.sh is

set -x
umask 000
source .venv/bin/activate

# PyTorch reads PYTORCH_CUDA_ALLOC_CONF; the original TORCH_CUDA_ALLOC_CONF
# spelling is not recognized, so expandable_segments never took effect.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export OMP_NUM_THREADS=1
export PRETRAINED_MODEL_NAME=Qwen/Qwen2.5-1.5B

# --query-gpu=count prints the total GPU count once per GPU; take one line.
export N_GPU=$(nvidia-smi --query-gpu=count --format=csv,noheader | head -n 1)
accelerate launch \
    --config_file configs/sft_zero1_4gpu.yaml \
    sft.py \
    model_name=${PRETRAINED_MODEL_NAME} \
    trainer.per_device_train_batch_size=4 \
    trainer.eval_steps=10 \
    trainer.seed=42 \
    trainer.full_determinism=true \
    compute.n_gpus=${N_GPU}

where my sft.py is

import os

# Re-enable tokenizer thread parallelism; set before `transformers` (or
# anything that imports it) is loaded so the setting takes effect.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from datetime import datetime

import hydra
from accelerate import PartialState
from accelerate.utils import broadcast_object_list
from datasets import Dataset, load_dataset
from omegaconf import DictConfig, OmegaConf
from transformers import AutoTokenizer, PreTrainedTokenizer

from trl import SFTTrainer
from trl.trainer.sft_config import SFTConfig


def sync_cfg(state: PartialState, cfg: DictConfig, key: str):
    """Broadcast ``cfg.<key>`` from rank 0 to all ranks, then barrier."""
    payload = [getattr(cfg, key)]
    broadcast_object_list(payload, from_process=0)
    setattr(cfg, key, payload[0])
    state.wait_for_everyone()


def format_sft_dataset(
    dataset: Dataset, tokenizer: PreTrainedTokenizer, num_proc: int = 16
) -> Dataset:
    """Render each example into chat-templated prompt/completion text.

    The prompt is wrapped as a single user turn when it is a plain string,
    otherwise treated as an already-structured message list; the completion
    gets the tokenizer's EOS token appended.
    """

    def _render(example: dict) -> dict:
        raw_prompt = example["prompt"]
        if isinstance(raw_prompt, str):
            messages = [{"role": "user", "content": raw_prompt}]
        else:
            messages = list(raw_prompt)
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        return {
            "prompt": prompt_text,
            "completion": example["completion"] + tokenizer.eos_token,
        }

    return dataset.map(_render, num_proc=num_proc)


@hydra.main(config_path="configs", config_name="sft.yaml", version_base=None)
def main(cfg: DictConfig):
    """Entry point: resolve the run config, load and format the SFT data,
    then train with TRL's SFTTrainer and save the final model."""
    state = PartialState()

    # add timestamp to exp_dir
    # Only rank 0 stamps and creates the directory; the resulting name is
    # broadcast below so all ranks agree on it.
    if state.is_main_process:
        cfg.exp_dir = f"{cfg.exp_dir}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(cfg.exp_dir, exist_ok=True)

    # sync
    sync_cfg(state, cfg, "exp_dir")
    OmegaConf.resolve(cfg)

    # resolve compute config
    # Derive gradient_accumulation_steps so that
    # global_batch_size == per_device_train_batch_size * n_gpus * accum_steps.
    assert (
        cfg.compute.global_batch_size
        % (cfg.trainer.per_device_train_batch_size * cfg.compute.n_gpus)
        == 0
    ), "global_batch_size must be divisible by per_device_train_batch_size * n_gpus"
    cfg.trainer.gradient_accumulation_steps = cfg.compute.global_batch_size // (
        cfg.trainer.per_device_train_batch_size * cfg.compute.n_gpus
    )
    print(f"Gradient accumulation steps: {cfg.trainer.gradient_accumulation_steps}")

    # load dataset
    if cfg.dataset.is_local:
        # For local files the builder name is taken from the file extension
        # (e.g. "parquet") and the file is passed via data_files.
        train_dataset = load_dataset(
            cfg.dataset.train.path.split(".")[-1],
            data_files=cfg.dataset.train.path,
            split=cfg.dataset.train.split,
        )
        eval_dataset = load_dataset(
            cfg.dataset.eval.path.split(".")[-1],
            data_files=cfg.dataset.eval.path,
            split=cfg.dataset.eval.split,
        )
    else:
        train_dataset = load_dataset(
            cfg.dataset.train.name, split=cfg.dataset.train.split
        )
        eval_dataset = load_dataset(cfg.dataset.eval.name, split=cfg.dataset.eval.split)

    # NOTE(review): the ModelScope hub patch is applied *after* the datasets
    # are loaded, so it only affects the tokenizer/model downloads below —
    # confirm this ordering is intentional.
    if cfg.use_ms:
        from modelscope.utils.hf_util import patch_hub

        patch_hub()

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    train_dataset = format_sft_dataset(train_dataset, tokenizer)
    eval_dataset = format_sft_dataset(eval_dataset, tokenizer)

    # start training
    if state.is_main_process:
        OmegaConf.save(cfg, os.path.join(cfg.exp_dir, "args.yaml"))
    # SFTTrainer accepts a model-name string and instantiates the model itself.
    trainer = SFTTrainer(
        model=cfg.model_name,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=SFTConfig(**OmegaConf.to_container(cfg.trainer, resolve=True)),
    )
    trainer.train()
    trainer.save_model(cfg.trainer.output_dir)


if __name__ == "__main__":
    main()

The two config files live in the same folder, named configs. Please create that folder first, then create sft.yaml and sft_zero1_4gpu.yaml as follows.

# sft.yaml
# Hydra run configuration consumed by sft.py.
defaults:
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

hydra:
  output_subdir: null
  run:
    dir: .

# Run directory; sft.py appends a timestamp and broadcasts it to all ranks.
exp_dir: logs/sft

# model
model_name: Qwen/Qwen2.5-1.5B
use_ms: true  # route model/tokenizer downloads through the ModelScope hub patch

# dataset
dataset:
  is_local: true
  train: 
    path: data/UltraFeedback/sft.parquet
    split: train
  eval:
    path: data/UltraFeedback/test_sft.parquet
    split: train  # local parquet files are loaded under a single "train" split

compute:
  global_batch_size: 256  # effective batch = per_device * n_gpus * grad_accum
  n_gpus: 4

# sft
trainer:
  output_dir: ${exp_dir}/ckpts
  max_length: 4096
  eval_strategy: steps
  eval_steps: 10
  # NOTE: train.sh overrides this to 4 on the command line.
  per_device_train_batch_size: 2
  # ??? marks a mandatory-missing value; sft.py computes it from
  # compute.global_batch_size before constructing the trainer.
  gradient_accumulation_steps: ???
  num_train_epochs: 3
  gradient_checkpointing: true
  activation_offloading: false
  bf16: true
  use_liger_kernel: true
  packing: false
  seed: 42
  full_determinism: true
  report_to: ["tensorboard"]
  logging_dir: ${exp_dir}/tensorboard
  save_strategy: "no"
# sft_zero1_4gpu.yaml
# accelerate config: single-node, 4-process DeepSpeed launch.
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  # NOTE(review): optimizer state is offloaded to host memory; confirm this
  # is intended when comparing stage-1 vs stage-2 runs, since offload adds
  # host<->device traffic to every optimizer step.
  offload_optimizer_device: cpu
  zero3_init_flag: false
  zero_stage: 1  # change to 2 if test zero_stage 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 4  # should match compute.n_gpus passed to sft.py
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  FP Quantizer is using an untested triton version (3.4.0), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
INFO:root:cc -pthread -Wno-unused-result -Wsign-compare -Wunreachable-code -DNDEBUG -g -fwrapv -O3 -Wall -fPIC -fPIC -c /tmp/tmppw0n0rnh/test.c -o /tmp/tmppw0n0rnh/test.o
INFO:root:cc -pthread /tmp/tmppw0n0rnh/test.o -L/usr/local/cuda-12.8 -L/usr/local/cuda-12.8/lib64 -lcufile -o /tmp/tmppw0n0rnh/a.out
gds .................... [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.8
 [WARNING]  using untested triton version (3.4.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['.venv/lib/python3.10/site-packages/torch']
torch version .................... 2.8.0+cu128
deepspeed install path ........... ['.venv/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.18.2, unknown, unknown
torch cuda version ............... 12.8
torch hip version ................ None
nvcc version ..................... 12.8
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 256.00 GB

System info (please complete the following information):

  • OS: Ubuntu 24.04
  • 1 machines with x4 A100s
  • Python 3.10

Launcher context

The above example uses Huggingface's accelerate. However, I also tried OpenRLHF, which uses deepspeed launcher.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), training

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions