I found the MLP language model a bit difficult to examine when I printed the embs variable. Building an index tensor first and only then looking up the embeddings would probably be easier to follow than gathering the embeddings directly inside the loop. I would suggest this code, which builds the index tensor, then extracts and flattens the embeddings:
idx_cat = []
for k in range(self.block_size):
    idx_cat.append(idx)                      # token indices shifted back by k positions, shape (b, t)
    idx = torch.roll(idx, 1, 1)
    idx[:, 0] = self.vocab_size              # special <BLANK> token
idx_cat = torch.stack(idx_cat, dim=1).transpose(1, 2)     # (b, t, block_size)
x = self.wte(idx_cat).flatten(start_dim=-2, end_dim=-1)   # (b, t, block_size * n_embd)
logits = self.mlp(x)
instead of
# gather the word embeddings of the previous 3 words
embs = []
for k in range(self.block_size):
    tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
    idx = torch.roll(idx, 1, 1)
    idx[:, 0] = self.vocab_size # special <BLANK> token
    embs.append(tok_emb)
# concat all of the embeddings together and pass through an MLP
x = torch.cat(embs, -1) # (b, t, n_embd * block_size)
logits = self.mlp(x)
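For what it's worth, here is a minimal standalone sketch that checks the two versions produce the same x. It assumes made-up small values for vocab_size, block_size and n_embd, and uses a plain nn.Embedding standing in for self.wte:

import torch
import torch.nn as nn

# hypothetical small configuration, chosen only for this demonstration
vocab_size, block_size, n_embd = 5, 3, 4
wte = nn.Embedding(vocab_size + 1, n_embd)   # +1 row for the <BLANK> token
idx0 = torch.randint(0, vocab_size, (2, 7))  # (b, t) batch of token indices

# current version: embed inside the loop, then concatenate
idx = idx0.clone()
embs = []
for k in range(block_size):
    embs.append(wte(idx))                    # (b, t, n_embd)
    idx = torch.roll(idx, 1, 1)
    idx[:, 0] = vocab_size                   # special <BLANK> token
x_embs = torch.cat(embs, -1)                 # (b, t, n_embd * block_size)

# suggested version: build the index tensor first, embed once, then flatten
idx = idx0.clone()
idx_cat = []
for k in range(block_size):
    idx_cat.append(idx)                      # (b, t)
    idx = torch.roll(idx, 1, 1)
    idx[:, 0] = vocab_size                   # special <BLANK> token
idx_cat = torch.stack(idx_cat, dim=1).transpose(1, 2)    # (b, t, block_size)
x_idx = wte(idx_cat).flatten(start_dim=-2, end_dim=-1)   # (b, t, block_size * n_embd)

print(torch.equal(x_embs, x_idx))  # expected: True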
In addition, the comment
# gather the word embeddings of the previous 3 words
is not correct, since the number of previous words is block_size, not necessarily 3.