
Clarify the hyperparameters for reproducing DiffuLLaMA pretraining? #25

@jwkirchenbauer


Hi there, thanks for the great work developing a public codebase that others can use to build upon your work!

I am working on reproducing the "DiffuLLaMA" model from the paper, i.e. a LLaMA-2 model adapted from standard AR next-token prediction to masked diffusion.

My understanding is that Figure 2 is the chart in the paper that summarizes how the training dynamics should look, and that the setup details are described in a combination of the text in Sec. 4.1 and App. B.2.

[Figure 2 from the paper]

my setup

To start, I simply prepared a smaller amount of webtext data (from Dolma) in the packed format adopted from Tinyllama, and then configured a launching script to run the pretraining job on my cluster. My dataset mixes different webtext components, I am on AMD GPUs, and I swapped from accelerate's deepspeed zero3 to plain DDP training; I am hoping we can set those differences aside for the moment and focus on the machine-learning-specific configuration parameters for the job.
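For reference, the packing I am doing looks roughly like the sketch below (a simplification in Python; the tokenizer/EOS handling is my own assumption and may not match the Tinyllama-style preprocessing exactly):

# Minimal packing sketch: concatenate tokenized docs with EOS separators,
# then slice into fixed 2048-token sequences.
from transformers import AutoTokenizer

SEQ_LEN = 2048

def pack_documents(docs, tokenizer):
    buffer = []
    for doc in docs:
        buffer.extend(tokenizer.encode(doc, add_special_tokens=False))
        buffer.append(tokenizer.eos_token_id)
        while len(buffer) >= SEQ_LEN:
            yield buffer[:SEQ_LEN]
            buffer = buffer[SEQ_LEN:]

# tokenizer = AutoTokenizer.from_pretrained("/hf_hub/pretrained/Llama-2-7b-hf")
# sequences = list(pack_documents(dolma_docs, tokenizer))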

Combining the parameters specified in DiffuLLaMA-training/run_distributed.sh with the hparams discussed in Sec. 4.1 and App. B.2, I arrived at the following launch command. The cluster request is for 16 nodes with 4 GPUs each, and because of the switch to the plain DDP strategy I reduce the micro-batch size and increase the accumulation steps to compensate and reach a similar global batch size.

# 65e9 toks / (16 nodes * 4 gpus * 16 mbsz * 16 accum * 2048 seqlen) = 1937.1 steps
MAX_TRAIN_STEPS=1938

DATASET_PATH=/dir/with/10b/tokens/of/webtext
MODEL_PATH=/hf_hub/pretrained/Llama-2-7b-hf

# srun -N16 -n64 ...

python -u train.py \
--wandb Diffusion \
--seed 2829 \
--max-train-steps $MAX_TRAIN_STEPS \
--learning-rate 2e-5 \
--dataset $DATASET_PATH \
--model $MODEL_PATH \
--parallel_mode data_parallel
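For completeness, the back-of-the-envelope arithmetic behind MAX_TRAIN_STEPS (not part of the training code, just restating the comment above in Python):

import math

target_tokens = 65e9                       # 65B adaptation tokens, as in the paper
nodes, gpus_per_node = 16, 4               # my cluster request
micro_bsz, grad_accum, seq_len = 16, 16, 2048  # my DDP settings

tokens_per_step = nodes * gpus_per_node * micro_bsz * grad_accum * seq_len
print(tokens_per_step)                     # 33,554,432 tokens per optimizer step
print(target_tokens / tokens_per_step)     # ~1937.1
print(math.ceil(target_tokens / tokens_per_step))  # 1938 -> MAX_TRAIN_STEPS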

A few questions:

1. very large global token batch size

While my exact step count comes out slightly different since my mbsz and accumulation level differ, the overall number is still close, about 2k steps of training. Something about this setup feels very odd to me. According to the paper hparams, the number of tokens in each global batch, computed as

The global batch size is calculated by multiplying the single GPU batch size, the number of gradient accumulation steps, and the number of GPUs

as stated in B.2, comes out to 16 nodes * 4 gpus * 60 mbsz * 4 accum * 2048 seqlen ≈ 31M tokens per optimization step. The number for my setup would be about 33M. This feels like a very large batch size for LLM training. Granted, I am mostly used to pretraining causal, autoregressive LLMs, but it is still a surprise, so I just want to confirm that this is accurate for the DiffuLLaMA recipe.
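Written out explicitly (just the B.2 formula as arithmetic, using the numbers quoted above):

# Paper setup per App. B.2: 16 nodes x 4 GPUs, micro-batch 60, 4 accumulation steps, seq len 2048
tokens_per_step = 16 * 4 * 60 * 4 * 2048
print(tokens_per_step)           # 31,457,280 -> ~31M tokens per optimizer step
print(65e9 / tokens_per_step)    # ~2066 optimizer steps to reach 65B tokens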

2. slow training step times

Given this configuration, the training step time on this cluster is quite slow: I am seeing roughly 100 seconds per step. Now, according to the paper, for the DiffuLLaMA run the code does

directly use bi-directional attention without attention mask annealing

which means the attention cost could be much higher than I am used to (flash attention doesn't run as fast with a non-causal mask), and this, combined with the ~30M-token batch size (rather than, say, 2M or 4M for typical LLM pretraining), could make each training step very expensive. But again, this runs against my intuitions: only ~2k optimization steps in total, yet each step takes minutes, so training to just 65B tokens still takes a long time. I want to confirm that the DiffuLLaMA recipe does indeed use these very large batches and a comparatively small number of optimization steps to reach the 65B-token mark.
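For reference, the throughput and wall-clock time this implies on my side (rough numbers, assuming ~100 s/step holds for the whole run):

steps = 1938
sec_per_step = 100                      # observed, approximate
tokens_per_step = 33_554_432            # from my config above
gpus = 64

total_hours = steps * sec_per_step / 3600
tok_per_sec = tokens_per_step / sec_per_step
print(f"{total_hours:.0f} hours total")              # ~54 hours
print(f"{tok_per_sec/1e3:.0f}k tokens/s overall")    # ~336k tokens/s
print(f"{tok_per_sec/gpus/1e3:.1f}k tokens/s/GPU")   # ~5.2k tokens/s per GPU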

3. lack of loss curve progress

Finally, with the job configuration described above, the loss is not going down anywhere near as fast as the Figure 2 curve suggests it should. Since my trial dataset is smaller, I would even expect slightly faster convergence, because the model starts repeating data not far into training. Instead, the loss curves make some early progress but then plateau around 7, rather than sinking toward 3 after just 10B tokens of training as in your figure (see the quick arithmetic check below).
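As a sanity check on what a plateau at ~7 means, assuming the plotted loss is a token-level cross-entropy over LLaMA-2's 32k vocabulary (in nats):

import math

vocab_size = 32000
uniform_ce = math.log(vocab_size)   # cross-entropy of a uniform (random-guess) prediction
print(uniform_ce)                   # ~10.37 nats

# So a plateau around 7 is well below random guessing, but still far from the
# ~3 that Figure 2 reaches after ~10B tokens.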

Does anything seem off in the hyperparameters or setup that might make the training progress so much more slowly? Forgive me if there is some obvious mistake or misunderstanding on my end regarding the code configuration or paper details.

[Loss curve from my run, plateauing around 7]

Thanks for taking a bit of time to understand what I am attempting, and for offering any clarifications or suggestions that come to mind!
