2 changes: 1 addition & 1 deletion README.md
@@ -125,7 +125,7 @@ This is simple, but very memory inefficient. If you want to train SAEs for many
torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
```

The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and splits each minibatch into 2 microbatches before feeding them into the SAE encoder, thus saving a lot of memory. It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node.
The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and sets `micro_acc_steps` to 2, which acts as a multiplier in the gradient accumulation calculation (it does not currently split minibatches or reduce memory use). It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8-GPU node.
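
For intuition, here is a minimal, self-contained sketch (not the `sparsify` implementation; `chunks`, `accumulate`, and `grad_fn` are hypothetical names) of how splitting a batch into `grad_acc_steps` minibatches, and each minibatch into `micro_acc_steps` microbatches, leaves the effective full-batch average unchanged:

```python
# Illustration only: gradient accumulation with microbatch splitting.
# Summing per-sample contributions over nested chunks, then dividing by
# the total batch size, recovers the full-batch average exactly.

def chunks(xs, n):
    """Split xs into n roughly equal consecutive chunks."""
    k, m = divmod(len(xs), n)
    out, i = [], 0
    for j in range(n):
        size = k + (1 if j < m else 0)
        out.append(xs[i:i + size])
        i += size
    return out

def accumulate(batch, grad_acc_steps, micro_acc_steps, grad_fn):
    """Accumulate grad_fn over grad_acc_steps minibatches, each split
    into micro_acc_steps microbatches before the (hypothetical) forward
    pass, then average over the whole batch."""
    total = 0.0
    for minibatch in chunks(batch, grad_acc_steps):
        for micro in chunks(minibatch, micro_acc_steps):
            total += sum(grad_fn(x) for x in micro)
    return total / len(batch)

data = list(range(16))
full_batch = sum(x * x for x in data) / len(data)
accumulated = accumulate(data, grad_acc_steps=8, micro_acc_steps=2,
                         grad_fn=lambda x: x * x)
assert abs(full_batch - accumulated) < 1e-9
```

The splitting only changes *when* contributions are computed, not *what* is averaged; any memory saving would come from running the forward pass on smaller microbatches, which is exactly what this PR's doc change says `micro_acc_steps` does not currently do.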

## TODO

4 changes: 3 additions & 1 deletion sparsify/config.py
@@ -52,7 +52,9 @@ class TrainConfig(Serializable):
"""Number of steps over which to accumulate gradients."""

micro_acc_steps: int = 1
"""Chunk the activations into this number of microbatches for training."""
"""Multiplier applied in the gradient accumulation calculation.
Note: does not currently split activations into microbatches or save memory.
"""

loss_fn: Literal["ce", "fvu", "kl"] = "fvu"
"""Loss function to use for training the sparse coders.