Code simplification #29

@pnugues

Description

I found the MLP language model a bit difficult to examine when I printed the `embs` variable. Constructing the index tensors first would probably be easier to follow than gathering the embeddings directly. I would suggest this code, which builds an index tensor first and then extracts and flattens the embeddings:

```python
idx_cat = []
for k in range(self.block_size):
    idx_cat.append(idx)
    idx = torch.roll(idx, 1, 1)
    idx[:, 0] = self.vocab_size  # special <BLANK> token
idx_cat = torch.stack(idx_cat, dim=1).transpose(1, 2)  # (b, t, block_size)
x = self.wte(idx_cat).flatten(start_dim=-2, end_dim=-1)  # (b, t, n_embd * block_size)
logits = self.mlp(x)
```

instead of

```python
# gather the word embeddings of the previous 3 words
embs = []
for k in range(self.block_size):
    tok_emb = self.wte(idx)  # token embeddings of shape (b, t, n_embd)
    idx = torch.roll(idx, 1, 1)
    idx[:, 0] = self.vocab_size  # special <BLANK> token
    embs.append(tok_emb)

# concat all of the embeddings together and pass through an MLP
x = torch.cat(embs, -1)  # (b, t, n_embd * block_size)
logits = self.mlp(x)
```
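As a sanity check, the two approaches can be compared on a toy configuration. The sizes below (`vocab_size`, `block_size`, `n_embd`, and the batch shape) are made up for illustration, and `wte` is a standalone `nn.Embedding` standing in for the model's embedding table:

```python
import torch
import torch.nn as nn

# Hypothetical small configuration, chosen only for this comparison.
vocab_size, block_size, n_embd = 10, 3, 4
torch.manual_seed(0)
wte = nn.Embedding(vocab_size + 1, n_embd)   # +1 row for the <BLANK> token
idx0 = torch.randint(0, vocab_size, (2, 5))  # (b, t) batch of token indices

# Original approach: gather the embeddings step by step, then concatenate.
idx = idx0.clone()
embs = []
for k in range(block_size):
    embs.append(wte(idx))       # (b, t, n_embd)
    idx = torch.roll(idx, 1, 1)
    idx[:, 0] = vocab_size      # special <BLANK> token
x_orig = torch.cat(embs, -1)    # (b, t, n_embd * block_size)

# Proposed approach: build the index tensor first, then embed and flatten.
idx = idx0.clone()
idx_cat = []
for k in range(block_size):
    idx_cat.append(idx)
    idx = torch.roll(idx, 1, 1)
    idx[:, 0] = vocab_size
idx_cat = torch.stack(idx_cat, dim=1).transpose(1, 2)   # (b, t, block_size)
x_new = wte(idx_cat).flatten(start_dim=-2, end_dim=-1)  # (b, t, n_embd * block_size)

print(torch.equal(x_orig, x_new))  # both paths produce the same input to the MLP
```

Since both variants look up exactly the same embedding rows in the same order, the resulting tensors are bit-for-bit identical, and `idx_cat` can be printed directly to inspect which tokens feed each position.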

In addition, the comment `# gather the word embeddings of the previous 3 words` is not correct: the number of previous words is `block_size`, not 3.
