diff --git a/README.md b/README.md
index a2916ce..d169a95 100644
--- a/README.md
+++ b/README.md
@@ -125,7 +125,7 @@ This is simple, but very memory inefficient. If you want to train SAEs for many
 torchrun --nproc_per_node gpu -m sparsify meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
 ```
 
-The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and splits each minibatch into 2 microbatches before feeding them into the SAE encoder, thus saving a lot of memory. It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node.
+The above command trains an SAE for every _even_ layer of Llama 3 8B, using all available GPUs. It accumulates gradients over 8 minibatches, and uses a `micro_acc_steps` multiplier of 2 for the gradient accumulation calculation. It also loads the model in 8-bit precision using `bitsandbytes`. This command requires no more than 48GB of memory per GPU on an 8 GPU node.
 
 ## TODO
diff --git a/sparsify/config.py b/sparsify/config.py
index 6854d04..fd5f9d6 100644
--- a/sparsify/config.py
+++ b/sparsify/config.py
@@ -52,7 +52,9 @@ class TrainConfig(Serializable):
     """Number of steps over which to accumulate gradients."""
 
     micro_acc_steps: int = 1
-    """Chunk the activations into this number of microbatches for training."""
+    """Multiplier for gradient accumulation.
+    Note: does not currently split data or save memory.
+    """
 
     loss_fn: Literal["ce", "fvu", "kl"] = "fvu"
     """Loss function to use for training the sparse coders.