How to use Grokfast with FP16 mixed precision training? #10

@peterjc123

Hi, I'm trying out Grokfast in an LLM scenario. Mixed precision training is a commonly used technique to reduce GPU memory usage and speed up training. The following code is an example of FP16 training.

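import torch
from torch import autocast
from torch.cuda.amp import GradScaler  # newer PyTorch releases also expose torch.amp.GradScaler

scaler = GradScaler()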

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

My question is: where should I put grads = gradfilter_ema(model, grads)? I tried putting it between scaler.scale(loss).backward() and scaler.unscale_(optimizer), but that doesn't work; the loss scale just explodes.
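To make the placement question concrete, here is a minimal, framework-free sketch of how an EMA gradient filter behaves when it sees scaled versus unscaled gradients. It does not reproduce the Grokfast code or GradScaler. The names ema_filter, filter_then_unscale and unscale_then_filter are hypothetical, and alpha=0.98, lamb=2.0 and the zero-initialized EMA are assumptions chosen for illustration. It only shows that a single overflowed (inf) scaled gradient permanently contaminates the EMA state if the filter runs before unscaling, while unscaling first and skipping non-finite steps keeps the state clean. Whether this explains the behavior reported above is untested.

```python
import math


def ema_filter(grad, ema, alpha=0.98, lamb=2.0):
    """Hypothetical scalar stand-in for an EMA gradient filter.

    Updates the running average, then amplifies the gradient with it.
    """
    ema = alpha * ema + (1.0 - alpha) * grad
    return grad + lamb * ema, ema


def filter_then_unscale(scaled_grads, scales):
    """Filter the still-scaled gradients, then divide by the loss scale."""
    ema, out = 0.0, []
    for g, s in zip(scaled_grads, scales):
        filtered, ema = ema_filter(g, ema)  # EMA absorbs scale changes and infs
        out.append(filtered / s)
    return out, ema


def unscale_then_filter(scaled_grads, scales):
    """Divide by the loss scale first, then filter only finite gradients."""
    ema, out = 0.0, []
    for g, s in zip(scaled_grads, scales):
        g = g / s
        if not math.isfinite(g):
            out.append(None)  # step is skipped, EMA state is left untouched
            continue
        filtered, ema = ema_filter(g, ema)
        out.append(filtered)
    return out, ema


# The true gradient is 1.0 at every step. Step 2 overflows in FP16,
# after which the (simulated) loss scale is halved.
scales = [65536.0, 65536.0, 32768.0, 32768.0]
scaled_grads = [65536.0, math.inf, 32768.0, 32768.0]

bad_out, bad_ema = filter_then_unscale(scaled_grads, scales)
ok_out, ok_ema = unscale_then_filter(scaled_grads, scales)

print("filter before unscale:", bad_out, bad_ema)
print("filter after unscale: ", ok_out, ok_ema)
```

In the training loop above, the second variant would correspond to calling gradfilter_ema after scaler.unscale_(optimizer), and only on iterations where the unscaled gradients are finite. That mapping is a suggestion under the assumptions stated here, not a confirmed fix.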
