-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpositional_encoding.py
More file actions
34 lines (27 loc) · 1002 Bytes
/
Copy pathpositional_encoding.py
File metadata and controls
34 lines (27 loc) · 1002 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import math
import torch
from torch import nn
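'''
Sinusoidal positional encoding, where i is the position in the sequence,
j indexes a sine/cosine column pair, and d is the hidden size (num_hiddens):
P(i, 2j) = sin(i / 10000^(2j/d))
P(i, 2j+1) = cos(i / 10000^(2j/d))
'''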
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    The precomputed table ``P`` has shape ``(1, max_len, num_hiddens)`` with
    ``P[0, i, 2j] = sin(i / 10000^(2j / num_hiddens))`` and
    ``P[0, i, 2j + 1] = cos(i / 10000^(2j / num_hiddens))``.

    Args:
        num_hiddens: Size of the last (feature) dimension of the table; may be
            even or odd.
        dropout: Probability passed to ``nn.Dropout``, applied after the
            encoding has been added.
        max_len: Number of positions precomputed in ``P``. Inputs with more
            than ``max_len`` positions along dimension 1 are not supported.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # angles[i, j] = i / 10000^(2j / num_hiddens); one column per even index.
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        exponents = torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens
        angles = positions / torch.pow(10000, exponents)

        # Create a long enough P.
        table = torch.zeros((1, max_len, num_hiddens))
        table[:, :, 0::2] = torch.sin(angles)
        # With an odd num_hiddens there is one fewer odd column than even
        # columns, so trim the angles to the number of odd slots.
        table[:, :, 1::2] = torch.cos(angles[:, : num_hiddens // 2])

        # Register as a buffer so the table follows the module across
        # .to()/.cuda() calls; persistent=False keeps it out of state_dict so
        # existing checkpoints still load unchanged.
        self.register_buffer("P", table, persistent=False)

    def forward(self, X):
        """Return ``dropout(X + P[:, :X.shape[1], :])``.

        Args:
            X: Tensor whose dimension 1 indexes positions and whose shape
                broadcasts against ``(1, X.shape[1], num_hiddens)``.

        Returns:
            The input with the positional table added, after dropout.
        """
        # .to() is a no-op when the table already lives on X's device; it is
        # kept so inputs on a different device than the module still work.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
if __name__ == '__main__':
    # Small demo: push an all-zero batch through the module with dropout 0.
    encoding_dim = 32
    num_steps = 60
    pos_encoding = PositionalEncoding(encoding_dim, 0)
    zero_batch = torch.zeros((1, num_steps, encoding_dim))
    X = pos_encoding(zero_batch)
    # Slice of the precomputed table covering the same positions as X.
    P = pos_encoding.P[:, :X.shape[1], :]