339 changes: 276 additions & 63 deletions flagai/data/tokenizer/uni_tokenizer/tokenizer.py
@@ -86,74 +86,292 @@ def __init__(self,
self.is_glm = self.tokenizer_model_name.lower().startswith('glm')
# self.is_clip = self.tokenizer_model_name.startswith('clip')
self.num_tokens = self.text_tokenizer.vocab_size
if self.tokenizer_model_name.startswith('cpm'):
special_tokens.append('eod')
if self.tokenizer_model_name.startswith('opt'):
special_tokens.append('bos')

try:
with open(self.special_tokens_map, encoding='utf8') as file: dct=json.load(file)
sp_tokens = [(k.replace("_token",""),v['content']) for k,v in dct.items()]
except FileNotFoundError:
dct = None
sp_tokens = []
for tk in special_tokens:
res = self.search_special(tk)
if res:
sp_tokens += [(tk, res)]
self._command_tokens = [CommandToken(e[0], e[1], self.text_tokenizer.convert_token_to_id(e[1])) for e in sp_tokens]
if self.tokenizer_model_name.lower().startswith("glm"):
if self.tokenizer_class == "wp":

# self._command_tokens = [CommandToken(e[0], e[1], self.text_tokenizer.convert_token_to_id(e[1])) for e in sp_tokens]
if self.tokenizer_class == "wp":
# set command tokens from wordpiece tokenizer values
self.num_command_tokens = 6
self.num_text_tokens = self.num_tokens - 5
self.num_type_tokens = 2
self.token_start_id = None
self.token_end_id = None
self.token_pad_id = None
try:
self._command_tokens = [
CommandToken(
'pad', '[PAD]',
self.text_tokenizer.convert_token_to_id('[PAD]')),
CommandToken(
'cls', '[CLS]',
self.text_tokenizer.convert_token_to_id('[CLS]')),
CommandToken(
'mask', '[MASK]',
self.text_tokenizer.convert_token_to_id('[MASK]')),
CommandToken(
'unk', '[UNK]',
self.text_tokenizer.convert_token_to_id('[UNK]')),
CommandToken(
'sep', '[SEP]',
self.text_tokenizer.convert_token_to_id('[SEP]')),
CommandToken(
'eos', '[PAD]',
self.text_tokenizer.convert_token_to_id('[PAD]')),
]
self.token_start_id = self.text_tokenizer.convert_token_to_id(
'[CLS]')
self.token_end_id = self.text_tokenizer.convert_token_to_id(
'[SEP]')
self.token_pad_id = self.text_tokenizer.convert_token_to_id(
'[PAD]')

self.text_tokenizer._token_cls = "[CLS]"
self.text_tokenizer._token_sep = "[SEP]"
fix_command_token = False
elif self.tokenizer_class == "sp":
fix_command_token = True

except KeyError:
self._command_tokens = [
CommandToken('pad', '<|endoftext|>', self.num_tokens),
CommandToken('eos', '<|endoftext|>', self.num_tokens),
CommandToken('sep', '[SEP]', self.num_tokens + 1),
CommandToken('cls', '[CLS]', self.num_tokens + 2),
CommandToken('mask', '[MASK]', self.num_tokens + 3, lstrip=True),
CommandToken('unk', '[UNK]', self.num_tokens + 4)
CommandToken(
'pad', '[PAD]',
self.text_tokenizer.convert_token_to_id('<pad>')),
CommandToken(
'cls', '[CLS]',
self.text_tokenizer.convert_token_to_id('<s>')),
CommandToken(
'MASK', '[MASK]',
self.text_tokenizer.convert_token_to_id('<mask>')),
CommandToken(
'unk', '[UNK]',
self.text_tokenizer.convert_token_to_id('<unk>')),
CommandToken(
'sep', '[SEP]',
self.text_tokenizer.convert_token_to_id('<sep>')),
CommandToken(
'eos', '[PAD]',
self.text_tokenizer.convert_token_to_id('</s>')),
]
self.num_tokens += 6
elif self.tokenizer_class == "bpe":
self.token_start_id = self.text_tokenizer.convert_token_to_id(
'<s>')
self.token_end_id = self.text_tokenizer.convert_token_to_id(
'</s>')
self.token_pad_id = self.text_tokenizer.convert_token_to_id(
'<pad>')
self.text_tokenizer._token_cls = "<s>"
self.text_tokenizer._token_sep = "</s>"
if add_block_symbols:
self.add_command_token('sop', '<|startofpiece|>')
self.add_command_token('eop', '<|endofpiece|>',)
if add_task_mask:
self.add_command_token('gMASK', '[gMASK]')
self.add_command_token('sMASK', '[sMASK]')
if add_decoder_mask:
self.add_command_token('dBLOCK', '[dBLOCK]')
if add_sentinel_token > 0:
for i in range(1, add_sentinel_token):
self.add_command_token(f'MASK{i}', f'[MASK{i}]')
self.add_command_token(f'sop{i}', f'<|startofpiece{i}|>')
elif self.tokenizer_class == "bpe":
if self.tokenizer_model_name.lower().startswith('roberta'):
self.num_command_tokens = 6
self.num_text_tokens = self.num_tokens - 3
self._command_tokens = [
CommandToken(
'pad', '<|endoftext|>',
self.text_tokenizer.convert_token_to_id('</s>')),
CommandToken(
'eos', '<|endoftext|>',
self.text_tokenizer.convert_token_to_id('</s>')),
CommandToken(
'sep', '[SEP]',
self.text_tokenizer.convert_token_to_id('</s>')),
CommandToken(
'cls', '[CLS]',
self.text_tokenizer.convert_token_to_id('<s>')),
CommandToken(
'mask',
'[MASK]',
self.text_tokenizer.convert_token_to_id('<mask>'),
lstrip=True),
CommandToken(
'unk', '[UNK]',
self.text_tokenizer.convert_token_to_id('<unk>'))
]
if add_block_symbols:
self._command_tokens.extend([
CommandToken('sop', '<|startofpiece|>',
self.num_tokens),
CommandToken('eop', '<|endofpiece|>',
self.num_tokens + 1)
])
self.num_tokens += 2
self.num_command_tokens += 2
self.token_end_id = self.text_tokenizer.convert_token_to_id(
'</s>')
elif self.tokenizer_model_name.lower().startswith('clip'):
self.num_command_tokens = 2
self._command_tokens = [
CommandToken(
'sot', '<start_of_text>',
self.text_tokenizer.convert_token_to_id('</s>')),
CommandToken(
'eot', '<end_of_text>',
self.text_tokenizer.convert_token_to_id('</s>')),
]
self.num_tokens += self.num_command_tokens
self.token_end_id = self.text_tokenizer.convert_token_to_id(
'</s>')
else:
self.num_command_tokens = 2
self.num_text_tokens = self.num_tokens - 1
self._command_tokens = [
CommandToken('pad', '<|endoftext|>',
self.text_tokenizer.encoder['<|endoftext|>']),
CommandToken('eos', '<|endoftext|>',
self.text_tokenizer.encoder['<|endoftext|>'])
CommandToken(
'pad', '<|endoftext|>',
self.text_tokenizer.convert_token_to_id(
'<|endoftext|>')),
CommandToken(
'eos', '<|endoftext|>',
self.text_tokenizer.convert_token_to_id(
'<|endoftext|>'))
]
self.token_end_id = self.text_tokenizer.convert_token_to_id(
'<|endoftext|>')
if add_block_symbols:
if self.tokenizer_model_name.lower().startswith('glm'):
unk_token_id = self.num_tokens + 5
cls_token_id = self.num_tokens + 2
num_tokens_to_add = 5
else:
unk_token_id = self.text_tokenizer.convert_token_to_id(
'<|endoftext|>')
cls_token_id = self.text_tokenizer.convert_token_to_id(
'<|endoftext|>')
num_tokens_to_add = 4
self._command_tokens.extend([
CommandToken('sop', '<|startofpiece|>',
self.num_tokens),
CommandToken('eop', '<|endofpiece|>',
self.num_tokens + 1),
CommandToken('cls', '[CLS]', cls_token_id),
CommandToken('MASK',
'[MASK]',
self.num_tokens + 3,
lstrip=True),
CommandToken('sep', '[SEP]', self.num_tokens + 4),
CommandToken('unk', '[UNK]', unk_token_id)
])
self.num_tokens += num_tokens_to_add
self.num_command_tokens += 6

if add_block_symbols:
if add_task_mask:
self._command_tokens.extend([
CommandToken('gMASK',
'[gMASK]',
self.num_tokens,
lstrip=True),
CommandToken('sMASK',
'[sMASK]',
self.num_tokens + 1,
lstrip=True)
])
self.num_tokens += 2
self.num_command_tokens += 2
if add_decoder_mask:
self._command_tokens.extend(
[CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)])
self.num_tokens += 1
self.num_command_tokens += 1

elif self.tokenizer_class == "sp":
self.num_command_tokens = 0
self.num_text_tokens = self.text_tokenizer.vocab_size
self.num_tokens = self.num_text_tokens

if self.tokenizer_model_name.lower().startswith('glm'):
pad_token_id = self.num_tokens
eos_token_id = self.num_tokens
unk_token_id = self.num_tokens + 4
else:
pad_token_id = self.text_tokenizer.convert_token_to_id('<pad>')
eos_token_id = self.text_tokenizer.convert_token_to_id('</s>')
unk_token_id = self.text_tokenizer.convert_token_to_id('<unk>')
self._command_tokens = [
CommandToken('pad', '<|endoftext|>', self.num_text_tokens),
CommandToken('eos', '<|endoftext|>', self.num_text_tokens),
CommandToken('sep', '[SEP]', self.num_text_tokens + 1),
CommandToken('cls', '[CLS]', self.num_text_tokens + 2),
CommandToken('mask',
'[MASK]',
self.num_text_tokens + 3,
lstrip=True),
CommandToken('unk', '[UNK]', self.num_text_tokens + 4)
]

self.num_tokens += 5
self.num_command_tokens += 6
self.token_end_id = self.text_tokenizer.convert_token_to_id(
'</s>')
if add_block_symbols:
sop_id = self.text_tokenizer.convert_token_to_id('<|startofpiece|>')
eop_id = self.text_tokenizer.convert_token_to_id('<|endofpiece|>')
self._command_tokens.extend([
CommandToken('sop', '<|startofpiece|>', self.num_tokens),
CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1),
CommandToken('cls', '[CLS]', self.num_tokens + 2),
CommandToken('mask',
'[MASK]',
self.num_tokens + 3,
lstrip=True),
CommandToken('sep', '[SEP]', self.num_tokens + 4),
CommandToken('unk', '[UNK]', self.num_tokens + 5)
CommandToken('sop', '<|startofpiece|>',
self.num_tokens + 1),
CommandToken('eop', '<|endofpiece|>', self.num_tokens + 2)
])
self.num_tokens += 6
if add_block_symbols:
if not self.tokenizer_class == "bpe":
self.add_command_token('sop', '<|startofpiece|>',self.tokenizer_class)
self.add_command_token('eop', '<|endofpiece|>',self.tokenizer_class)
if fix_command_token:
self.num_tokens += 3
else:
self.num_tokens += 2
self.num_command_tokens += 2
if add_task_mask:
if fix_command_token:
self.add_command_token('sMASK', '[sMASK]',self.tokenizer_class)
self.add_command_token('gMASK', '[gMASK]',self.tokenizer_class)
self._command_tokens.extend([
CommandToken('sMASK',
'[sMASK]',
self.num_tokens,
lstrip=True),
CommandToken('gMASK',
'[gMASK]',
self.num_tokens + 1,
lstrip=True)
])
else:
self.add_command_token('gMASK', '[gMASK]',self.tokenizer_class)
self.add_command_token('sMASK', '[sMASK]',self.tokenizer_class)
self._command_tokens.extend([
CommandToken('gMASK',
'[gMASK]',
self.num_tokens,
lstrip=True),
CommandToken('sMASK',
'[sMASK]',
self.num_tokens + 1,
lstrip=True)
])
self.num_tokens += 2
self.num_command_tokens += 2
if add_decoder_mask:
self.add_command_token('dBLOCK', '[dBLOCK]',self.tokenizer_class)
if add_sentinel_token > 0:
for i in range(1, add_sentinel_token):
self.add_command_token(f'MASK{i}', f'[MASK{i}]',self.tokenizer_class)
self.add_command_token(f'sop{i}', f'<|startofpiece{i}|>',self.tokenizer_class)
self._command_tokens.extend(
[CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)])
self.num_tokens += 1
self.num_command_tokens += 1

if self.tokenizer_model_name.startswith('cpm'):
special_tokens.append('eod')
if self.tokenizer_model_name.startswith('opt'):
special_tokens.append('bos')
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
try:
with open(self.special_tokens_map, encoding='utf8') as file: dct=json.load(file)
sp_tokens = [(k.replace("_token",""),v['content']) for k,v in dct.items()]
for e in sp_tokens:
if e not in self.command_name_map:
self.add_command_token(e[0],e[1],self.tokenizer_class)
except FileNotFoundError:
dct = None
sp_tokens = []
for tk in special_tokens:
if tk not in self.command_name_map:
res = self.search_special(tk)
self.add_command_token(tk, res,self.tokenizer_class)

self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {
tok.token: tok
@@ -166,11 +166,6 @@ def __init__(self,
if not self.token_start_id:
self.token_start_id = vocab.get('[CLS]', None)

self.token_end_id = vocab.get('</s>', None)
if not self.token_end_id:
self.token_end_id = vocab.get('<|endoftext|>', None)
if not self.token_end_id:
self.token_end_id = vocab.get('[SEP]', None)
print("All special tokens: ", str([(k, v.token, v.Id) for k,v in self.command_name_map.items()]))

def get_vocab(self):
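The hunk above replaces the hard-coded, per-tokenizer-class CommandToken lists with a data-driven flow: special tokens are read from special_tokens_map.json and registered through add_command_token, with search_special used as a fallback for names the map does not supply. The following is a minimal standalone sketch of that flow; the helper name load_special_token_pairs and its signature are illustrative only and do not appear in the PR, and it assumes the JSON values are dicts with a "content" field, as the diff does.

    import json

    def load_special_token_pairs(map_path, fallback_names, search_special):
        """Return (name, token) pairs from special_tokens_map.json,
        falling back to search_special() when the file is missing."""
        try:
            with open(map_path, encoding="utf8") as f:
                mapping = json.load(f)
            # e.g. {"pad_token": {"content": "<pad>"}, ...} -> [("pad", "<pad>"), ...]
            return [(k.replace("_token", ""), v["content"]) for k, v in mapping.items()]
        except FileNotFoundError:
            return [(name, search_special(name)) for name in fallback_names]

Each returned pair would then be passed to add_command_token (skipping names already present in command_name_map), which is the same shape as the try/except block near the end of the __init__ hunk.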
@@ -583,7 +796,7 @@ def search_special(self, name):
elif self.check_special('<|endoftext|>'): return '<|endoftext|>'
elif name == "eos":
if self.check_special('</s>'): return '</s>'
elif self.check_special('|endoftext|'): return '|endoftext|'
elif self.check_special('<|endoftext|>'): return '<|endoftext|>'
elif self.check_special('[PAD]'): return '[PAD]'
elif name == "sep":
if self.check_special('<sep>'): return '<sep>'
Expand All @@ -598,7 +811,7 @@ def search_special(self, name):
elif self.check_special('<mask>'): return '<mask>'
elif name == "eod":
if self.check_special('<eod>'): return '<eod>'
return None
return '<addnew>'

def check_special(self, tk):
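Two behavioral changes show up in the search_special hunk: the eos lookup now checks the correctly spelled '<|endoftext|>' (previously '|endoftext|'), and the function returns '<addnew>' instead of None when no known candidate is present in the vocabulary. A simplified sketch of the eos lookup order, under the assumption that check_special is effectively a vocabulary membership test (vocab here stands in for the underlying tokenizer vocabulary):

    def search_eos(vocab):
        # Try the eos candidates in the order used by search_special.
        for candidate in ("</s>", "<|endoftext|>", "[PAD]"):
            if candidate in vocab:
                return candidate
        return "<addnew>"  # new default when no known eos token exists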
