-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtokenizer.py
More file actions
279 lines (250 loc) · 12.4 KB
/
tokenizer.py
File metadata and controls
279 lines (250 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
import regex as re
import os
from tqdm import tqdm
class Tokenizer():
    """Byte-level byte-pair-encoding (BPE) tokenizer with special-token support.

    Token ids 0-255 are the raw byte values. Every learned merge gets the next
    free id starting at 256, and special tokens are registered after the merges.

    Attributes:
        vocab: dict[int, bytes] mapping every token id to the bytes it decodes to.
        merges: dict[tuple[int, int], int] mapping a pair of ids to the merged id.
        spcl_tokens: dict[int, str] mapping special-token ids to their text.
        inverse_spcl_tokens: dict[str, int], the reverse of ``spcl_tokens``.
        padding_token: text of the padding token.
        pt: compiled pattern used to split text into chunks before merging.
    """

    def __init__(self, padding_token:str='[PAD]'):
        """Create an untrained tokenizer that only knows the 256 byte values.

        Args:
            padding_token: Text of the padding token, e.g. "<pad>" or "[PAD]".
                It is registered as a special token by ``train``.
        """
        # bytes([idx]) is a one-byte object; bytes(idx) would be idx zero bytes.
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.merges = {}
        self.spcl_tokens = {}          # dict(int, str)
        self.inverse_spcl_tokens = {}  # dict(str, int)
        self.padding_token = padding_token
        # Raw string: the backslash escapes are meant for the regex engine, not
        # for Python's string parser. Splits text into contractions, letter
        # runs, digit runs of at most three, punctuation runs and whitespace.
        self.pt = re.compile(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""")

    def __len__(self):
        """Return the number of entries in the vocabulary."""
        return len(self.vocab)

    def __get_stats(self, encoded_text:list[int], stats:"dict | None"=None, verbose:bool=False)->dict:
        """Count every adjacent pair of ids in ``encoded_text``.

        Args:
            encoded_text: List of integer ids.
            stats: Optional dictionary of counts to update in place, so that
                counts can be accumulated over several sequences.
            verbose: Prints the counts if True. False by default.

        Returns:
            The dictionary mapping each pair to its count, in first-seen order.
        """
        stats = {} if stats is None else stats
        for pair in zip(encoded_text, encoded_text[1:]):
            stats[pair] = stats.get(pair, 0) + 1
        if verbose:
            print("*"*120)
            print(stats)
            print("*"*120)
        return stats

    def __merge_items(self, token_ids:list[int], pair:tuple[int, int], new_token:int)->list[int]:
        """Replace every non-overlapping occurrence of ``pair`` with ``new_token``.

        Args:
            token_ids: List of token ids.
            pair: The two consecutive ids to merge.
            new_token: The id that replaces each occurrence of the pair.

        Returns:
            A new list; ``token_ids`` itself is not modified.
        """
        new_token_ids = []
        i = 0
        last = len(token_ids) - 1
        while i < len(token_ids):
            if i < last and token_ids[i] == pair[0] and token_ids[i+1] == pair[1]:
                new_token_ids.append(new_token)
                i += 2  # both members of the pair are consumed
            else:
                new_token_ids.append(token_ids[i])
                i += 1
        return new_token_ids

    def vocabulary(self)-> dict[int, bytes]:
        """Return the vocabulary: token ids mapped to their byte strings."""
        return self.vocab

    def train(self, corpus_path:str, target_vocab_size:int = 16000, verbose:bool=False)->None:
        """Train the tokenizer on a UTF-8 text corpus.

        Training always restarts from the 256 byte values. Special tokens that
        were registered before training are registered again after the merges,
        followed by the padding token, so their ids never collide with merge ids.

        Args:
            corpus_path: Path to the text file.
            target_vocab_size: Desired number of byte and merge tokens. Fewer
                merges are learned when the corpus runs out of pairs.
            verbose: Prints merging information if True. False by default.

        Raises:
            ValueError: If ``target_vocab_size`` is not greater than 256.
        """
        if target_vocab_size <= 256:
            raise ValueError("target_vocab_size must be greater than 256")
        self.num_merges = target_vocab_size - 256
        with open(corpus_path, 'r', encoding='utf-8') as f:
            text = f.read()
        chunks = [list(chunk.encode(encoding='utf-8', errors='replace')) for chunk in self.pt.findall(text)]

        carried_over = list(self.inverse_spcl_tokens)
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.merges = {}
        self.spcl_tokens = {}
        self.inverse_spcl_tokens = {}

        self.__learn_merges(chunks, tqdm(range(self.num_merges)), verbose)
        self.add_special_tokens([*carried_over, self.padding_token])

    def __learn_merges(self, chunks:list[list[int]], steps, verbose:bool=False)->None:
        """Learn one merge per element of ``steps`` and record it.

        Args:
            chunks: Id sequences of the corpus chunks; merges never cross chunks.
            steps: Any iterable; its length caps the number of merges.
            verbose: Prints each merge if True. False by default.
        """
        for _ in steps:
            stats = {}
            for chunk in chunks:
                self.__get_stats(chunk, stats)
            if not stats:
                break  # every chunk is a single token: nothing left to merge
            # max() keeps the first-seen pair among equally frequent ones.
            best_pair = max(stats, key=stats.get)
            new_id = 256 + len(self.merges)
            chunks = [self.__merge_items(chunk, best_pair, new_id) for chunk in chunks]
            self.merges[best_pair] = new_id
            self.vocab[new_id] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
            if verbose:
                print(f"The byte token pair {best_pair} got merged to {new_id}")

    def decode(self, token_ids:list, skip_special_tokens:bool=False)->str:
        """Convert a list of token ids back to a string.

        Args:
            token_ids: List of token ids to decode.
            skip_special_tokens: Leaves special tokens out of the text if True.
                False by default, i.e. special tokens are kept.

        Returns:
            The decoded text. Byte sequences that are not valid UTF-8 are
            replaced by the Unicode replacement character.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        if skip_special_tokens:
            token_ids = [idx for idx in token_ids if idx not in self.spcl_tokens]
        byte_tokens = b''.join(self.vocab[idx] for idx in token_ids)
        # A token boundary may fall inside a multi-byte character, hence 'replace'.
        return byte_tokens.decode(encoding="utf-8", errors='replace')

    def __encode_chunk(self, token_ids:list[int], verbose:bool=False)->list[int]:
        """Apply the learned merges to one chunk, earliest-learned merge first."""
        while len(token_ids) > 1:  # a single id has nothing to merge
            stats = self.__get_stats(token_ids, {}, verbose)
            # The lowest merge id is the merge learned first; unknown pairs sort last.
            pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if pair not in self.merges:
                break  # no valid merges are left
            token_ids = self.__merge_items(token_ids, pair, self.merges[pair])
        return token_ids

    def __encode_ordinary(self, text:str, verbose:bool=False)->list[int]:
        """Encode text that contains no special tokens.

        The text is split with the same pattern that ``train`` uses, so merges
        are applied inside each chunk only, exactly as they were learned.

        Args:
            text: Text to encode.
            verbose: Shows the pair counts if True. False by default.

        Returns:
            List of token ids.
        """
        ids = []
        for chunk in self.pt.findall(text):
            raw = list(chunk.encode(encoding='utf-8', errors='replace'))
            ids.extend(self.__encode_chunk(raw, verbose))
        return ids

    def encode(self, text:str, verbose:bool=False)->list[int]:
        """Convert text into a list of token ids.

        Registered special tokens found in the text map to their own id; all
        other text goes through the byte-pair merges.

        Args:
            text: Text to encode.
            verbose: Shows the pair counts if True. False by default.

        Returns:
            List of token ids.
        """
        pattern = self.__split_pattern()
        if pattern is None:
            return self.__encode_ordinary(text, verbose)
        ids = []
        # The pattern is one capturing group, so split() keeps the special tokens.
        for piece in re.split(pattern, text):
            if piece in self.inverse_spcl_tokens:
                ids.append(self.inverse_spcl_tokens[piece])
            elif piece:
                ids.extend(self.__encode_ordinary(piece, verbose))
        return ids

    def __split_pattern(self)->"str | None":
        """Build the pattern that splits text on special tokens.

        Returns:
            A single capturing group alternating over the escaped special
            tokens, or None when no special token is registered.
        """
        if not self.inverse_spcl_tokens:
            return None
        # Longest first, so a token is never shadowed by a shorter prefix of it.
        ordered = sorted(self.inverse_spcl_tokens, key=len, reverse=True)
        return "(" + "|".join(re.escape(token) for token in ordered) + ")"

    def save(self, path:str=None)->None:
        """Write 'tokenizer.model' and 'vocab.bpe' into a 'Tokenizer' folder.

        tokenizer.model is what ``load`` reads; vocab.bpe is for human
        inspection only. Only the text of each special token is stored, in
        ascending id order.

        Args:
            path: Directory in which the 'Tokenizer' folder is created. Defaults
                to the directory that contains this module.
        """
        base = os.path.dirname(__file__) if path is None else path
        folder = os.path.join(base, "Tokenizer")
        os.makedirs(folder, exist_ok=True)
        with open(os.path.join(folder, "tokenizer.model"), 'w', encoding='utf-8') as f:
            f.write("==Tokenizer==\n")
            # Always written, even when empty, because load() expects the section.
            f.write("**Special Tokens**\n")
            f.write(f"{len(self.spcl_tokens)}\n")
            for idx in sorted(self.spcl_tokens):
                f.write(f"{self.spcl_tokens[idx]}\n")
            f.write(f"{len(self.merges)}\n")
            for (left, right), idx in self.merges.items():
                f.write(f"{left},{right}:{idx}\n")
        with open(os.path.join(folder, "vocab.bpe"), 'w', encoding='utf-8') as f:
            for k, v in self.vocab.items():
                f.write(f"{k} ====> {v} ====> <[ {v.decode(encoding='utf-8', errors='replace')} ]>\n")
        print(f"Successfully saved 'tokenizer.model' and 'vocab.bpe' inside {folder} folder!!!")

    def load(self, path:str=None)->None:
        """Load merges and special tokens from a 'tokenizer.model' file.

        The current state is replaced only after the file was parsed
        successfully. Special tokens get consecutive ids after the merges, in
        file order. Files without a special-token section are accepted too.

        Args:
            path: Path of the 'tokenizer.model' file. Defaults to
                'Tokenizer/tokenizer.model' next to this module.

        Raises:
            ValueError: If the file does not start with the expected header or
                a count or merge line cannot be parsed.
        """
        path = path if path is not None else os.path.join(os.path.dirname(__file__), "Tokenizer", "tokenizer.model")
        vocab = {i: bytes([i]) for i in range(256)}
        merges = {}
        special_tokens = []
        with open(path, 'r', encoding='utf-8') as f:
            # Read first, check second: an assert would vanish under "python -O"
            # together with the readline() that consumes the header.
            if f.readline().strip() != "==Tokenizer==":
                raise ValueError(f"{path} is not a tokenizer.model file")
            line = f.readline().strip()
            if line == "**Special Tokens**":
                print(f"Unpacking {line}")
                num_spcl_tokens = int(f.readline().strip())
                special_tokens = [f.readline().rstrip("\r\n") for _ in range(num_spcl_tokens)]
                line = f.readline().strip()
            num_merges = int(line)
            for _ in range(num_merges):
                pair_text, _, idx_text = f.readline().strip().partition(":")
                left, right = map(int, pair_text.split(","))
                idx = int(idx_text)
                merges[(left, right)] = idx
                vocab[idx] = vocab[left] + vocab[right]
        self.vocab = vocab
        self.merges = merges
        self.spcl_tokens = {}
        self.inverse_spcl_tokens = {}
        self.add_special_tokens(special_tokens=special_tokens)
        print("Tokenizer is loaded sucessfully!!!")

    def add_special_tokens(self, special_tokens:list[str])->None:
        """Register new special tokens for the tokenizer to recognize.

        Tokens that are already registered are ignored.

        Args:
            special_tokens: List of special strings, e.g. ['[sos]', '[eos]'].
        """
        for token in special_tokens:
            if token in self.inverse_spcl_tokens:
                continue
            new_id = len(self.vocab)
            # After a removal the vocabulary has a gap, so the length alone
            # may point at an id that is still in use.
            while new_id in self.vocab:
                new_id += 1
            self.spcl_tokens[new_id] = token
            self.inverse_spcl_tokens[token] = new_id
            self.vocab[new_id] = token.encode(encoding='utf-8')

    def remove_special_tokens(self, special_tokens:list[str])->None:
        """Delete special tokens; strings that are not registered are ignored.

        Args:
            special_tokens: List of special strings, e.g. ['[sos]', '[eos]'].
        """
        for token in special_tokens:
            idx = self.inverse_spcl_tokens.pop(token, None)
            if idx is None:
                continue
            self.spcl_tokens.pop(idx, None)
            self.vocab.pop(idx, None)

    def display_special_tokens(self)->None:
        """Print the special tokens, or a hint on how to add some."""
        if self.spcl_tokens:
            print(f"Special Tokens: {self.spcl_tokens}")
        else:
            print("No Special Tokens added. \nUse add_special_tokens() to add a list of special tokens to the tokenizer.")

    def get_pad_token_index(self)->int:
        """Return the id of the padding token, e.g. for building a padding mask.

        Raises:
            KeyError: If the padding token is not registered.
        """
        return self.inverse_spcl_tokens[self.padding_token]