Transformer.py
import torch
from torch import nn

from .Encoder import EncoderLayer
from .Decoder import DecoderLayer
from .PositionalEncoding import PositionalEncoding


class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from stacked EncoderLayer and DecoderLayer blocks."""

    def __init__(self, embed_size, num_heads, ff_dim, num_layers, vocab_size):
        super().__init__()
        # Token embedding shared by the source and target sequences.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(embed_size, num_heads, ff_dim) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(embed_size, num_heads, ff_dim) for _ in range(num_layers)]
        )
        # Projects decoder outputs back to vocabulary logits.
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, x, y=None):
        # Embed the source tokens and add positional information.
        x = self.embedding(x)
        x = self.positional_encoding(x)
        if y is None:
            # No target provided: reuse the encoded source as the decoder input.
            y = x
        else:
            # Cast to LongTensor so nn.Embedding accepts the target indices.
            y = self.embedding(y.long())
            y = self.positional_encoding(y)
        for layer in self.encoder_layers:
            x = layer(x)
        for layer in self.decoder_layers:
            y = layer(y, x)
        # Vocabulary logits for every target position.
        return self.fc(y)
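

# --- Minimal usage sketch (not part of the original file) --------------------
# Assumes EncoderLayer, DecoderLayer and PositionalEncoding operate on
# (batch, seq_len, embed_size) tensors and that token indices lie in
# [0, vocab_size); the hyperparameters below are illustrative only.
if __name__ == "__main__":
    model = Transformer(embed_size=64, num_heads=4, ff_dim=256, num_layers=2, vocab_size=1000)
    src = torch.randint(0, 1000, (2, 16))  # (batch, src_len) source token indices
    tgt = torch.randint(0, 1000, (2, 16))  # (batch, tgt_len) target token indices
    logits = model(src, tgt)               # (batch, tgt_len, vocab_size)
    print(logits.shape)                    # expected: torch.Size([2, 16, 1000])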