Skip to content

lambdv/smolGPT

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

6 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

smolgpt

a small generative pre-trained transformer model implemented in pytorch. smolgpt

class Transformer(torch.nn.Module):
    """gpt 2 like transformer architecture"""
    def __init__(self, d_model=embed_dim, n_heads=12, dropout=0.1, n_layers=12):
        super().__init__()
        # tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.n_layers = n_layers
    
        # embedding layer
        self.embedding_layer = EmbeddingLayer(vocab_size, d_model, dropout, block_size)
        # transformer blocks
        self.blocks = torch.nn.ModuleList()
        for _ in range(n_layers):
            self.blocks.append(TransformerBlock(d_model, n_heads, dropout))
        # final layer norm
        self.final_layer_norm = LayerNorm(d_model, True)
        # linear projection to vocab
        self.lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)

        self.embedding_layer.token_embedding.weight = self.lm_head.weight 
        
        _init(self)

def linear_projection(token_embedding, Wq, Wk, Wv):
    """learnt weights for turning token embedding into K, Q, V"""
    K = Wk(token_embedding)
    Q = Wq(token_embedding)
    V = Wv(token_embedding)
    return K, Q, V

def mased_self_attention(x, Wq, Wk, Wv, Wo, n_heads, d_head):
    """get weighted sums between all tokens by batching masked attention calculations"""
    batchs, token_sequences, dimension = x.shape
    K, Q, V = linear_projection(x, Wq, Wk, Wv)
    Q = Q.view(batchs, token_sequences, n_heads, d_head).transpose(1, 2)
    K = K.view(batchs, token_sequences, n_heads, d_head).transpose(1, 2)
    V = V.view(batchs, token_sequences, n_heads, d_head).transpose(1, 2)

    attention_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_head)
    mask = torch.tril(torch.ones(token_sequences, token_sequences, device=x.device)).view(1, 1, token_sequences, token_sequences)  
    attention_scores = attention_scores.masked_fill(mask == 0, -1e9)  # apply mask

    weights = torch.softmax(attention_scores, dim=-1)
    weighted_sums = weights @ V
    retun Wo(weighted_sums.transpose(1, 2).contiguous().view(batchs, token_sequences, dimension))

class TransformerBlock(torch.nn.Module):
    """learns liguistic patterns: compute attention and MLP with norms and residual connections"""
    def __init__(self, d_model=embed_dim, n_heads=12, dropout=0.1):
        super().__init__()
        # normalize list of token embeddings
        self.norm_1 = LayerNorm(embed_dim, True)
        self.norm_2 = LayerNorm(embed_dim, True)
        # learn KQV weights for each token embedding
        self.Wq, self.Wk, self.Wv = (
            torch.nn.Linear(embed_dim, embed_dim, bias=False),
            torch.nn.Linear(embed_dim, embed_dim, bias=False),
            torch.nn.Linear(embed_dim, embed_dim, bias=False),
        )
        self.Wo = torch.nn.Linear(d_model, d_model, bias=False)
        # 'room to learn new patterns' 768 -> 3072 -> non linear artivation (GELU) -> 768
        self.MLP = MLP(d_model, dropout)

    def forward(self, x):
        x = self._residual(x, self.norm_1, self._self_attention)
        x = self._residual(x, self.norm_2, self.MLP)
        return x

    def _residual(self, x, normal, sublayer):
        return x + sublayer(normal(x))

    def _self_attention(self, x):
        d_head = self.d_model // self.n_heads
        return mased_self_attention(x, self.Wq, self.Wk, self.Wv, self.Wo, self.n_heads, d_head)

def MLP(d_model=embed_dim, dropout=0.1):
    return torch.nn.Sequential(
        torch.nn.Linear(d_model, d_model * 4, bias=False),  # eg: 768 -> 3072
        torch.nn.GELU(),  # non-linear activation function
        torch.nn.Linear(d_model * 4, d_model, bias=False),  # eg: 3072 -> 768
        torch.nn.Dropout(
            dropout
        ),  # randomly drop out some tokens to prevent overfitting
    )

getting started

  1. install dependencies
python -m venv .venv
pip install -r requirements.txt
./.venv/Scripts/activate
  1. train
python prepare.py shakespeare_char # prepare data
python train.py # train model weights
  1. generate text
python use.py "Once upon a time"

About

a small implementation of gpt 2 architecture in pytorch

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages