abacus.py
"""Implementation of abacus embeddings"""
# Example of how to extract digit tokens to pass into constructor
# digit_tokens = tokenizer.convert_tokens_to_ids(['0','1','2','3','4','5','6','7','8','9'])
class Abacus(torch.nn.Module):
    """
    Abacus Embeddings, learned embeddings reused for each digit.
    Integers must be reversed for this to work correctly.
    Transformers Can Do Arithmetic with the Right Embeddings, McLeish et al. (2024)
    """
    def __init__(self, digit_tokens, embedding_dim, max_seq_length=1024, max_k=99):
        """
        digit_tokens (list): list of the tokens for each of the 10 digits, `digit_tokens = tokenizer.convert_tokens_to_ids(['0','1','2','3','4','5','6','7','8','9'])`
        embedding_dim (int): dimension to embed into
        max_seq_length (int): maximum number of embeddings that can be trained
        max_k (int): maximum k value which we randomly shift by during training
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("digits", torch.tensor(digit_tokens), persistent=False)
        self.max_k = max_k

    def helper(self, mask, device):
        """
        Converts a binary mask of digit locations into spans of consecutive digits
        """
        mask_shape = mask.shape
        # Create a shifted version of the mask to detect changes from 0 to 1
        shifted_mask = torch.cat([torch.zeros((mask_shape[0], 1), device=device, dtype=mask.dtype), mask[:, :-1]], dim=1)
        starts = (shifted_mask != mask) & mask
        # Generate IDs for each segment of 1s, processing row-wise
        segment_ids = torch.cumsum(starts, dim=1)
        # Generate an index array row-wise
        index = torch.arange(mask.size(1)).repeat(mask.size(0), 1).to(device)
        # Reset index at the start of each segment
        reset_index = torch.zeros_like(mask).long()
        second_term = index * starts.long()
        reset_index = reset_index.scatter_add(1, segment_ids, second_term)
        # Calculate positions in segment
        positions = index - reset_index.gather(1, segment_ids) + 1
        # Ensure only values within 1-segments are non-zero
        result = positions * mask
        return result

    def forward(self, input_ids):
        """
        input_ids (tensor): a batch of inputs, each row is a sample
        """
        mask = torch.isin(input_ids, self.digits)
        output = self.helper(mask, input_ids.device)

        k = 0
        if self.training:
            k = random.randint(0, self.max_k)

        output[output > 0] += k  # as we already have ones in the tensor, the tensor values will be k+1
        return self.embedding(output)
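
# A minimal usage sketch (not part of the original file): the token ids below are
# placeholders, and in a real model the returned embeddings would typically be added
# to the token embeddings inside the transformer's forward pass.
if __name__ == "__main__":
    # Pretend digits '0'-'9' map to token ids 0-9; normally these come from
    # tokenizer.convert_tokens_to_ids(['0','1','2','3','4','5','6','7','8','9']).
    digit_tokens = list(range(10))
    abacus = Abacus(digit_tokens, embedding_dim=64)
    abacus.eval()  # disable the random offset k applied during training

    # One sample: non-digit tokens (ids 11, 12, 13) separate two digit runs.
    input_ids = torch.tensor([[11, 3, 7, 2, 12, 5, 9, 13]])
    pos_emb = abacus(input_ids)
    print(pos_emb.shape)  # torch.Size([1, 8, 64])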