85 changes: 78 additions & 7 deletions README.md
@@ -303,7 +303,22 @@ model = DenseTransformer(cfg)

### Optimization

The optimizer system uses a **registry-based, strategy-pattern architecture** that decouples optimizer logic from execution mode (single-GPU vs. distributed). The `Trainer` class (available in [`gpt_lab.train.trainer`](./src/gpt_lab/train/trainer.py)) builds the optimizer by calling `DenseTransformer.build_optimizer`, which is implemented in the `DenseTransformer` class (available in [`gpt_lab.model.gpt`](./src/gpt_lab/model/gpt.py)). The optimizer is configured from [`configs/optim.yaml`](./configs/optim.yaml), which can specify mixed optimizer groups and can be edited to add new optimizers or adjust existing ones.

> [!WARNING]
> This is perhaps the most critical part of the library for model training, and it is also the part I implemented least myself. I relied on several external repositories as a code baseline and went back and forth with LLMs to improve it. My goal was to make it work while keeping it modular. However, my understanding of optimization algorithms, combined with `torch.compile` and distributed training, is quite limited, so I encourage you to check the code in [`gpt_lab.optim.factory`](./src/gpt_lab/optim/factory.py) and the corresponding subfolders for the different optimizers.

#### Optimizer Architecture

The system is organized around four key components:

- **Registry** ([`gpt_lab.optim.registry`](./src/gpt_lab/optim/registry.py)): Defines optimizer specifications with required hyperparameters and validates YAML configs early.
- **Strategy Interface** ([`gpt_lab.optim.strategy`](./src/gpt_lab/optim/strategy.py)): Abstract `OptimizerStrategy` base class with methods for `local_step`, `dist_reduce`, and `dist_compute` (3-phase async pattern).
- **Concrete Strategies** ([`gpt_lab.optim.strategies/`](./src/gpt_lab/optim/strategies/)): Per-optimizer implementations (AdamW, Muon, etc.). No cross-strategy duplication.
- **Factory** ([`gpt_lab.optim.factory`](./src/gpt_lab/optim/factory.py)): Single entry point that automatically selects single-GPU or distributed backend and delegates to the appropriate strategies.
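
To make the registry idea concrete, here is a minimal sketch of how a spec registry with early config validation could look. It is not the code in `gpt_lab.optim.registry`: the `OptimizerSpec` fields mirror the registration example in this README, but `_REGISTRY`, `validate_group`, the placeholder `AdamWStrategy`, and the required keys chosen for `adamw` are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet


@dataclass(frozen=True)
class OptimizerSpec:
    # Hypothetical stand-in; the real class in gpt_lab.optim may differ.
    name: str
    required_keys: FrozenSet[str]
    strategy_class: type


_REGISTRY: Dict[str, OptimizerSpec] = {}  # assumed module-level registry


def register_optimizer(spec: OptimizerSpec) -> None:
    """Add a spec to the registry, refusing duplicate names."""
    if spec.name in _REGISTRY:
        raise ValueError(f"optimizer {spec.name!r} is already registered")
    _REGISTRY[spec.name] = spec


def validate_group(group: dict) -> OptimizerSpec:
    """Return the spec for one config group, or raise before training starts."""
    name = group.get("opt")
    if name not in _REGISTRY:
        raise KeyError(f"unknown optimizer {name!r}; known: {sorted(_REGISTRY)}")
    spec = _REGISTRY[name]
    missing = spec.required_keys - set(group)
    if missing:
        raise KeyError(f"optimizer {name!r} is missing keys: {sorted(missing)}")
    return spec


class AdamWStrategy:
    """Placeholder strategy class used only by this example."""


# The required keys below are illustrative, not the library's actual list.
register_optimizer(
    OptimizerSpec("adamw", frozenset({"opt", "lr", "weight_decay"}), AdamWStrategy)
)

spec = validate_group({"opt": "adamw", "lr": 1e-3, "weight_decay": 0.1})
```

Validating against the spec at construction time means a misspelled or missing hyperparameter surfaces as a `KeyError` before any training step runs.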

#### Configuration Example

```yaml
default:
  opt: "adamw"
# @@ -325,23 +340,79 @@ transformer:
  ...
```

The optimization process is decoupled from the model architecture and implemented as a separate, swappable component. The optimizer is built from the model and training configurations using a factory pattern; the implementations live in [`gpt_lab.optim.factory`](./src/gpt_lab/optim/factory.py) and the per-optimizer subfolders.

#### Usage

```python
from gpt_lab.optim import OptimizerFactory

model = ...
optim_cfg = ...  # dict of optimizer hyperparameters per group
param_groups = [
    {"params": model.embeddings.parameters(), **optim_cfg["embeddings"]},
    {"params": model.blocks.parameters(), **optim_cfg["transformer"]},
    ...
]
optimizer = OptimizerFactory(param_groups)
```

The factory automatically:
- Validates all required keys for each optimizer at construction (fails fast on config errors)
- Detects DDP mode and selects single-GPU or distributed backend
- Manages a reusable scalar cache for `torch.compile` stability
- Routes parameter groups to the correct strategy at each step
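
The routing and backend-selection behaviour listed above can be sketched roughly as follows. This is an illustration under stated assumptions, not the `OptimizerFactory` implementation: `MixedOptimizer`, `RecordingStrategy`, the `name` key on each group, and the explicit `world_size` argument are all hypothetical. A real implementation would presumably read the world size from `torch.distributed` instead of taking it as a parameter.

```python
from typing import Dict, List


class RecordingStrategy:
    """Hypothetical strategy that only records which groups it updated."""

    def __init__(self) -> None:
        self.seen: List[str] = []

    def local_step(self, group: dict) -> None:
        self.seen.append(group["name"])


class MixedOptimizer:
    """Toy dispatcher: one strategy per optimizer name, one backend per run."""

    def __init__(self, param_groups: List[dict], strategies: Dict[str, object],
                 world_size: int = 1) -> None:
        # Fail fast: every group must name a known optimizer.
        for group in param_groups:
            if group.get("opt") not in strategies:
                raise KeyError(f"no strategy for optimizer {group.get('opt')!r}")
        self.param_groups = param_groups
        self.strategies = strategies
        # Assumption: more than one process means the distributed backend.
        self.backend = "distributed" if world_size > 1 else "single"

    def step(self) -> None:
        if self.backend != "single":
            raise NotImplementedError("distributed path omitted in this sketch")
        # Route each group to the strategy registered for its optimizer.
        for group in self.param_groups:
            self.strategies[group["opt"]].local_step(group)


adamw, muon = RecordingStrategy(), RecordingStrategy()
optimizer = MixedOptimizer(
    [{"name": "embeddings", "opt": "adamw"}, {"name": "transformer", "opt": "muon"}],
    {"adamw": adamw, "muon": muon},
)
optimizer.step()
```

Keeping the dispatch in one place is what lets each strategy stay unaware of the other optimizers in the same run.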

#### Adding a New Optimizer

To add a new optimizer (e.g., Aurora), add one strategy file, export it, register it, and reference it in the YAML config:

1. Create strategy file `src/gpt_lab/optim/strategies/aurora.py`:
```python
from gpt_lab.optim.strategy import OptimizerStrategy

class AuroraStrategy(OptimizerStrategy):
    def local_step(self, group: dict) -> None:
        # Single-GPU update logic
        pass

    def dist_reduce(self, group: dict, world_size: int) -> dict:
        # Phase 1: launch async reductions, return info dict
        pass

    def dist_compute(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
        # Phases 2-3: wait, compute, launch gathers
        pass
```

2. Export it from `src/gpt_lab/optim/strategies/__init__.py`:
```python
from gpt_lab.optim.strategies.aurora import AuroraStrategy
__all__ = [..., "AuroraStrategy"]
```

3. Register in `src/gpt_lab/optim/factory.py` (in the "Register built-in optimizers" section):
```python
register_optimizer(OptimizerSpec(
    name="aurora",
    required_keys=frozenset({"opt", "lr", "momentum", "weight_decay"}),
    strategy_class=AuroraStrategy,
))
```

4. Use in YAML:
```yaml
optimizer:
  - opt: aurora
    lr: 1e-4
    momentum: 0.9
    weight_decay: 0.0
```
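
To show how the `dist_reduce` / `dist_compute` split from step 1 might be driven, here is a toy sketch that simulates the three phases with threads instead of real collectives. Only the two method signatures come from this README. The `ToyStrategy` class, the `distributed_step` driver, the scalar "parameters", and the use of `concurrent.futures` in place of asynchronous `torch.distributed` handles are assumptions made so the example runs on its own.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyStrategy:
    """Hypothetical strategy: plain SGD on one scalar, with simulated async comms."""

    def __init__(self, pool: ThreadPoolExecutor) -> None:
        self.pool = pool

    def dist_reduce(self, group: dict, world_size: int) -> dict:
        # Phase 1: launch the "reduction" (average of per-rank grads) and return at once.
        future = self.pool.submit(lambda: sum(group["rank_grads"]) / world_size)
        return {"reduce_future": future}

    def dist_compute(self, group: dict, info: dict, gather_list: list,
                     rank: int, world_size: int) -> None:
        avg_grad = info["reduce_future"].result()      # Phase 2: wait for the reduction
        group["param"] -= group["lr"] * avg_grad       # compute the update
        # Phase 3: launch the "gather" of the updated value, to be awaited later.
        gather_list.append(self.pool.submit(lambda: group["param"]))


def distributed_step(groups: list, strategy: ToyStrategy, rank: int, world_size: int) -> None:
    # Launch every reduction first so communication can overlap with later compute.
    infos = [strategy.dist_reduce(group, world_size) for group in groups]
    gather_list: list = []
    for group, info in zip(groups, infos):
        strategy.dist_compute(group, info, gather_list, rank, world_size)
    for future in gather_list:                         # final wait on all gathers
        future.result()


with ThreadPoolExecutor(max_workers=2) as pool:
    groups = [{"param": 1.0, "lr": 0.1, "rank_grads": [1.0, 3.0]}]
    distributed_step(groups, ToyStrategy(pool), rank=0, world_size=2)
```

Splitting the step this way means a new strategy only has to say what to reduce and what to compute; the ordering of waits stays in the driver.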

#### Current Implementations

- **AdamW** ([`strategies/adamw.py`](./src/gpt_lab/optim/strategies/adamw.py)): Standard AdamW with ZeRO-2 style sharding for large parameters.
- **Muon** ([`strategies/muon.py`](./src/gpt_lab/optim/strategies/muon.py)): MomentUm Orthogonalized by Newton-Schulz with efficient stacked parameter updates.

Both support single-GPU and multi-GPU distributed training with async communication patterns.
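
For readers unfamiliar with the Newton-Schulz step that Muon is named after, the sketch below orthogonalizes a small matrix in pure Python. It uses the classic cubic iteration `X <- 1.5 X - 0.5 X Xᵀ X`. Muon implementations commonly use a tuned higher-order variant with only a few steps, and which variant [`strategies/muon.py`](./src/gpt_lab/optim/strategies/muon.py) uses is not stated here. The helper names are hypothetical and the input is assumed to be non-zero and full rank.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz_orthogonalize(g, steps=30):
    """Approximate the orthogonal polar factor of g with the cubic Newton-Schulz iteration."""
    # Scale by the Frobenius norm so every singular value is at most 1,
    # which keeps the iteration inside its convergence region.
    norm = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / norm for v in row] for row in g]
    for _ in range(steps):
        xxt_x = matmul(matmul(x, transpose(x)), x)
        # Each step pushes all singular values toward 1 without changing singular vectors.
        x = [[1.5 * a - 0.5 * b for a, b in zip(row_x, row_c)]
             for row_x, row_c in zip(x, xxt_x)]
    return x


q = newton_schulz_orthogonalize([[2.0, 1.0], [0.5, 3.0]])
gram = matmul(q, transpose(q))  # close to the identity if q is orthogonal
```

In a Muon-style update, this orthogonalization would be applied to the momentum matrix of each 2D weight before the learning rate is applied.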

#### Pre training
