diff --git a/flagai/data/tokenizer/uni_tokenizer/tokenizer.py b/flagai/data/tokenizer/uni_tokenizer/tokenizer.py
index 4f0c0cde..1aa2fb04 100644
--- a/flagai/data/tokenizer/uni_tokenizer/tokenizer.py
+++ b/flagai/data/tokenizer/uni_tokenizer/tokenizer.py
@@ -86,74 +86,292 @@ def __init__(self,
         self.is_glm = self.tokenizer_model_name.lower().startswith('glm')
         # self.is_clip = self.tokenizer_model_name.startswith('clip')
         self.num_tokens = self.text_tokenizer.vocab_size
-        if self.tokenizer_model_name.startswith('cpm'):
-            special_tokens.append('eod')
-        if self.tokenizer_model_name.startswith('opt'):
-            special_tokens.append('bos')
-
-        try:
-            with open(self.special_tokens_map, encoding='utf8') as file: dct=json.load(file)
-            sp_tokens = [(k.replace("_token",""),v['content']) for k,v in dct.items()]
-        except FileNotFoundError:
-            dct = None
-            sp_tokens = []
-            for tk in special_tokens:
-                res = self.search_special(tk)
-                if res:
-                    sp_tokens += [(tk, res)]
-        self._command_tokens = [CommandToken(e[0], e[1], self.text_tokenizer.convert_token_to_id(e[1])) for e in sp_tokens]
-        if self.tokenizer_model_name.lower().startswith("glm"):
-            if self.tokenizer_class == "wp":
+
+        # self._command_tokens = [CommandToken(e[0], e[1], self.text_tokenizer.convert_token_to_id(e[1])) for e in sp_tokens]
+        if self.tokenizer_class == "wp":
+            # set command tokens from wordpiece tokenizer values
+            self.num_command_tokens = 6
+            self.num_text_tokens = self.num_tokens - 5
+            self.num_type_tokens = 2
+            self.token_start_id = None
+            self.token_end_id = None
+            self.token_pad_id = None
+            try:
+                self._command_tokens = [
+                    CommandToken(
+                        'pad', '[PAD]',
+                        self.text_tokenizer.convert_token_to_id('[PAD]')),
+                    CommandToken(
+                        'cls', '[CLS]',
+                        self.text_tokenizer.convert_token_to_id('[CLS]')),
+                    CommandToken(
+                        'mask', '[MASK]',
+                        self.text_tokenizer.convert_token_to_id('[MASK]')),
+                    CommandToken(
+                        'unk', '[UNK]',
+                        self.text_tokenizer.convert_token_to_id('[UNK]')),
+                    CommandToken(
+                        'sep', '[SEP]',
+                        self.text_tokenizer.convert_token_to_id('[SEP]')),
+                    CommandToken(
+                        'eos', '[PAD]',
+                        self.text_tokenizer.convert_token_to_id('[PAD]')),
+                ]
+                self.token_start_id = self.text_tokenizer.convert_token_to_id(
+                    '[CLS]')
+                self.token_end_id = self.text_tokenizer.convert_token_to_id(
+                    '[SEP]')
+                self.token_pad_id = self.text_tokenizer.convert_token_to_id(
+                    '[PAD]')
+
                 self.text_tokenizer._token_cls = "[CLS]"
                 self.text_tokenizer._token_sep = "[SEP]"
-                fix_command_token = False
-            elif self.tokenizer_class == "sp":
-                fix_command_token = True
+
+            except KeyError:
                 self._command_tokens = [
-                    CommandToken('pad', '<|endoftext|>', self.num_tokens),
-                    CommandToken('eos', '<|endoftext|>', self.num_tokens),
-                    CommandToken('sep', '[SEP]', self.num_tokens + 1),
-                    CommandToken('cls', '[CLS]', self.num_tokens + 2),
-                    CommandToken('mask', '[MASK]', self.num_tokens + 3, lstrip=True),
-                    CommandToken('unk', '[UNK]', self.num_tokens + 4)
+                    CommandToken(
+                        'pad', '[PAD]',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'cls', '[CLS]',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'MASK', '[MASK]',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'unk', '[UNK]',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'sep', '[SEP]',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'eos', '[PAD]',
+                        self.text_tokenizer.convert_token_to_id('')),
                 ]
-                self.num_tokens += 6
-            elif self.tokenizer_class == "bpe":
+                self.token_start_id = self.text_tokenizer.convert_token_to_id(
+                    '')
+                self.token_end_id = self.text_tokenizer.convert_token_to_id(
+                    '')
+                self.token_pad_id = self.text_tokenizer.convert_token_to_id(
+                    '')
+                self.text_tokenizer._token_cls = ""
+                self.text_tokenizer._token_sep = ""
+            if add_block_symbols:
+                self.add_command_token('sop', '<|startofpiece|>')
+                self.add_command_token('eop', '<|endofpiece|>',)
+                if add_task_mask:
+                    self.add_command_token('gMASK', '[gMASK]')
+                    self.add_command_token('sMASK', '[sMASK]')
+                if add_decoder_mask:
+                    self.add_command_token('dBLOCK', '[dBLOCK]')
+                if add_sentinel_token > 0:
+                    for i in range(1, add_sentinel_token):
+                        self.add_command_token(f'MASK{i}', f'[MASK{i}]')
+                        self.add_command_token(f'sop{i}', f'<|startofpiece{i}|>')
+        elif self.tokenizer_class == "bpe":
+            if self.tokenizer_model_name.lower().startswith('roberta'):
+                self.num_command_tokens = 6
+                self.num_text_tokens = self.num_tokens - 3
+                self._command_tokens = [
+                    CommandToken(
+                        'pad', '<|endoftext|>',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'eos', '<|endoftext|>',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'sep', '[SEP]',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'cls', '[CLS]',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'mask',
+                        '[MASK]',
+                        self.text_tokenizer.convert_token_to_id(''),
+                        lstrip=True),
+                    CommandToken(
+                        'unk', '[UNK]',
+                        self.text_tokenizer.convert_token_to_id(''))
+                ]
+                if add_block_symbols:
+                    self._command_tokens.extend([
+                        CommandToken('sop', '<|startofpiece|>',
+                                     self.num_tokens),
+                        CommandToken('eop', '<|endofpiece|>',
+                                     self.num_tokens + 1)
+                    ])
+                    self.num_tokens += 2
+                    self.num_command_tokens += 2
+                self.token_end_id = self.text_tokenizer.convert_token_to_id(
+                    '')
+            elif self.tokenizer_model_name.lower().startswith('clip'):
+                self.num_command_tokens = 2
+                self._command_tokens = [
+                    CommandToken(
+                        'sot', '',
+                        self.text_tokenizer.convert_token_to_id('')),
+                    CommandToken(
+                        'eot', '',
+                        self.text_tokenizer.convert_token_to_id('')),
+                ]
+                self.num_tokens += self.num_command_tokens
+                self.token_end_id = self.text_tokenizer.convert_token_to_id(
+                    '')
+            else:
+                self.num_command_tokens = 2
+                self.num_text_tokens = self.num_tokens - 1
                 self._command_tokens = [
-                    CommandToken('pad', '<|endoftext|>',
-                                 self.text_tokenizer.encoder['<|endoftext|>']),
-                    CommandToken('eos', '<|endoftext|>',
-                                 self.text_tokenizer.encoder['<|endoftext|>'])
+                    CommandToken(
+                        'pad', '<|endoftext|>',
+                        self.text_tokenizer.convert_token_to_id(
+                            '<|endoftext|>')),
+                    CommandToken(
+                        'eos', '<|endoftext|>',
+                        self.text_tokenizer.convert_token_to_id(
+                            '<|endoftext|>'))
                 ]
+                self.token_end_id = self.text_tokenizer.convert_token_to_id(
+                    '<|endoftext|>')
+                if add_block_symbols:
+                    if self.tokenizer_model_name.lower().startswith('glm'):
+                        unk_token_id = self.num_tokens + 5
+                        cls_token_id = self.num_tokens + 2
+                        num_tokens_to_add = 5
+                    else:
+                        unk_token_id = self.text_tokenizer.convert_token_to_id(
+                            '<|endoftext|>')
+                        cls_token_id = self.text_tokenizer.convert_token_to_id(
+                            '<|endoftext|>')
+                        num_tokens_to_add = 4
+                    self._command_tokens.extend([
+                        CommandToken('sop', '<|startofpiece|>',
+                                     self.num_tokens),
+                        CommandToken('eop', '<|endofpiece|>',
+                                     self.num_tokens + 1),
+                        CommandToken('cls', '[CLS]', cls_token_id),
+                        CommandToken('MASK',
+                                     '[MASK]',
+                                     self.num_tokens + 3,
+                                     lstrip=True),
+                        CommandToken('sep', '[SEP]', self.num_tokens + 4),
+                        CommandToken('unk', '[UNK]', unk_token_id)
+                    ])
+                    self.num_tokens += num_tokens_to_add
+                    self.num_command_tokens += 6
+
+            if add_block_symbols:
+                if add_task_mask:
+                    self._command_tokens.extend([
+                        CommandToken('gMASK',
+                                     '[gMASK]',
+                                     self.num_tokens,
+                                     lstrip=True),
+                        CommandToken('sMASK',
+                                     '[sMASK]',
+                                     self.num_tokens + 1,
+                                     lstrip=True)
+                    ])
+                    self.num_tokens += 2
+                    self.num_command_tokens += 2
+                if add_decoder_mask:
+                    self._command_tokens.extend(
+                        [CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)])
+                    self.num_tokens += 1
+                    self.num_command_tokens += 1
+
+        elif self.tokenizer_class == "sp":
+            self.num_command_tokens = 0
+            self.num_text_tokens = self.text_tokenizer.vocab_size
+            self.num_tokens = self.num_text_tokens
+
+            if self.tokenizer_model_name.lower().startswith('glm'):
+                pad_token_id = self.num_tokens
+                eos_token_id = self.num_tokens
+                unk_token_id = self.num_tokens + 4
+            else:
+                pad_token_id = self.text_tokenizer.convert_token_to_id('')
+                eos_token_id = self.text_tokenizer.convert_token_to_id('')
+                unk_token_id = self.text_tokenizer.convert_token_to_id('')
+            self._command_tokens = [
+                CommandToken('pad', '<|endoftext|>', self.num_text_tokens),
+                CommandToken('eos', '<|endoftext|>', self.num_text_tokens),
+                CommandToken('sep', '[SEP]', self.num_text_tokens + 1),
+                CommandToken('cls', '[CLS]', self.num_text_tokens + 2),
+                CommandToken('mask',
+                             '[MASK]',
+                             self.num_text_tokens + 3,
+                             lstrip=True),
+                CommandToken('unk', '[UNK]', self.num_text_tokens + 4)
+            ]
+
+            self.num_tokens += 5
+            self.num_command_tokens += 6
+            self.token_end_id = self.text_tokenizer.convert_token_to_id(
+                '')
+            if add_block_symbols:
+                sop_id = self.text_tokenizer.convert_token_to_id('<|startofpiece|>')
+                eop_id = self.text_tokenizer.convert_token_to_id('<|endofpiece|>')
                 self._command_tokens.extend([
-                    CommandToken('sop', '<|startofpiece|>', self.num_tokens),
-                    CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1),
-                    CommandToken('cls', '[CLS]', self.num_tokens + 2),
-                    CommandToken('mask',
-                                 '[MASK]',
-                                 self.num_tokens + 3,
-                                 lstrip=True),
-                    CommandToken('sep', '[SEP]', self.num_tokens + 4),
-                    CommandToken('unk', '[UNK]', self.num_tokens + 5)
+                    CommandToken('sop', '<|startofpiece|>',
+                                 self.num_tokens + 1),
+                    CommandToken('eop', '<|endofpiece|>', self.num_tokens + 2)
                 ])
-                self.num_tokens += 6
-        if add_block_symbols:
-            if not self.tokenizer_class == "bpe":
-                self.add_command_token('sop', '<|startofpiece|>',self.tokenizer_class)
-                self.add_command_token('eop', '<|endofpiece|>',self.tokenizer_class)
+                if fix_command_token:
+                    self.num_tokens += 3
+                else:
+                    self.num_tokens += 2
+                self.num_command_tokens += 2
             if add_task_mask:
                 if fix_command_token:
-                    self.add_command_token('sMASK', '[sMASK]',self.tokenizer_class)
-                    self.add_command_token('gMASK', '[gMASK]',self.tokenizer_class)
+                    self._command_tokens.extend([
+                        CommandToken('sMASK',
+                                     '[sMASK]',
+                                     self.num_tokens,
+                                     lstrip=True),
+                        CommandToken('gMASK',
+                                     '[gMASK]',
+                                     self.num_tokens + 1,
+                                     lstrip=True)
+                    ])
                 else:
-                    self.add_command_token('gMASK', '[gMASK]',self.tokenizer_class)
-                    self.add_command_token('sMASK', '[sMASK]',self.tokenizer_class)
+                    self._command_tokens.extend([
+                        CommandToken('gMASK',
+                                     '[gMASK]',
+                                     self.num_tokens,
+                                     lstrip=True),
+                        CommandToken('sMASK',
+                                     '[sMASK]',
+                                     self.num_tokens + 1,
+                                     lstrip=True)
+                    ])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
             if add_decoder_mask:
-                self.add_command_token('dBLOCK', '[dBLOCK]',self.tokenizer_class)
-            if add_sentinel_token > 0:
-                for i in range(1, add_sentinel_token):
-                    self.add_command_token(f'MASK{i}', f'[MASK{i}]',self.tokenizer_class)
-                    self.add_command_token(f'sop{i}', f'<|startofpiece{i}|>',self.tokenizer_class)
+                self._command_tokens.extend(
+                    [CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)])
+                self.num_tokens += 1
+                self.num_command_tokens += 1
+
+        if self.tokenizer_model_name.startswith('cpm'):
+            special_tokens.append('eod')
+        if self.tokenizer_model_name.startswith('opt'):
+            special_tokens.append('bos')
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        try:
+            with open(self.special_tokens_map, encoding='utf8') as file: dct=json.load(file)
+            sp_tokens = [(k.replace("_token",""),v['content']) for k,v in dct.items()]
+            for e in sp_tokens:
+                if e not in self.command_name_map:
+                    self.add_command_token(e[0],e[1],self.tokenizer_class)
+        except FileNotFoundError:
+            dct = None
+            sp_tokens = []
+            for tk in special_tokens:
+                if tk not in self.command_name_map:
+                    res = self.search_special(tk)
+                    self.add_command_token(tk, res,self.tokenizer_class)
+
         self.command_name_map = {tok.name: tok for tok in self._command_tokens}
         self.command_token_map = {
             tok.token: tok
@@ -166,11 +384,6 @@ def __init__(self,
 
         if not self.token_start_id:
             self.token_start_id = vocab.get('[CLS]', None)
-        self.token_end_id = vocab.get('', None)
-        if not self.token_end_id:
-            self.token_end_id = vocab.get('<|endoftext|>', None)
-        if not self.token_end_id:
-            self.token_end_id = vocab.get('[SEP]', None)
         print("All special tokens: ", str([(k, v.token, v.Id) for k,v in self.command_name_map.items()]))
 
     def get_vocab(self):
@@ -583,7 +796,7 @@ def search_special(self, name):
             elif self.check_special('<|endoftext|>'): return '<|endoftext|>'
         elif name == "eos":
             if self.check_special(''): return ''
-            elif self.check_special('|endoftext|'): return '|endoftext|'
+            elif self.check_special('<|endoftext|>'): return '<|endoftext|>'
             elif self.check_special('[PAD]'): return '[PAD]'
         elif name == "sep":
             if self.check_special(''): return ''
@@ -598,7 +811,7 @@ def search_special(self, name):
             elif self.check_special(''): return ''
         elif name == "eod":
             if self.check_special(''): return ''
-        return None
+        return ''
 
     def check_special(self, tk):
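
The fallback added at the end of __init__ above first seeds command_name_map, then tries to read the checkpoint's special_tokens_map JSON and registers the listed tokens via add_command_token; if the file is absent it falls back to search_special over the special_tokens list, skipping names already present in command_name_map. The Python sketch below only illustrates the name/content extraction step; the JSON layout shown is an assumed Hugging Face-style example, not taken from an actual FlagAI checkpoint.

# Minimal sketch of the special_tokens_map fallback shown above.
# The JSON content here is hypothetical (assumed HF-style layout).
import json
import tempfile

example = {
    "pad_token": {"content": "<pad>"},
    "unk_token": {"content": "<unk>"},
    "eos_token": {"content": "</s>"},
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(example, f)
    path = f.name

# Mirrors the two lines in the diff:
#   with open(self.special_tokens_map, encoding='utf8') as file: dct = json.load(file)
#   sp_tokens = [(k.replace("_token", ""), v['content']) for k, v in dct.items()]
with open(path, encoding="utf8") as file:
    dct = json.load(file)
sp_tokens = [(k.replace("_token", ""), v["content"]) for k, v in dct.items()]

# In the FileNotFoundError path, names already registered in command_name_map
# are skipped; the remainder would be handed to add_command_token(name, token, tokenizer_class).
already_registered = {"pad", "eos"}  # stand-in for self.command_name_map keys
to_add = [(name, tok) for name, tok in sp_tokens if name not in already_registered]
print(to_add)  # -> [('unk', '<unk>')]
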
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index c72d34a7..ac8cda5c 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -19,12 +19,12 @@ def test_tokenizer_GLM_large_ch(self):
         self.assertEqual(tokenizer.encode_plus('今天吃饭吃了肯德基')['input_ids'],
                          [50006, 3378, 1567, 2613, 20282, 50001],
                          'encode_plus Error')
         self.assertEqual(set([(k, v.token, v.Id) for k,v in tokenizer.command_name_map.items()]),
-                         {('pad', '<|endoftext|>', 50000), ('eos', '<|endoftext|>', 50000), ('sep', '[SEP]', 50001),
-                          ('cls', '[CLS]', 50002), ('mask', '[MASK]', 50003), ('unk', '[UNK]', 50004), ('sop', '<|startofpiece|>', 50006),
-                          ('eop', '<|endofpiece|>', 50007), ('sMASK', '[sMASK]', 50008), ('gMASK', '[gMASK]', 50009)}, 'SpecialTokens error')
+                         {('sop', '<|startofpiece|>', 50006), ('gMASK', '[gMASK]', 50007), ('pad', '<|endoftext|>', 50000),
+                          ('sep', '[SEP]', 50001), ('eos', '<|endoftext|>', 50000), ('unk', '[UNK]', 50004), ('sMASK', '[sMASK]', 50008),
+                          ('mask', '[MASK]', 50003), ('cls', '[CLS]', 50002), ('eop', '<|endofpiece|>', 50007)}, 'SpecialTokens error')
 
     def test_tokenizer_GLM_large_en(self):
         tokenizer = Tokenizer.from_pretrained("GLM-large-en")
         self.assertEqual(tokenizer.TokenToId("day"), 2154, '')
         self.assertEqual(tokenizer.EncodeAsIds("fried chicken makes me happy"),
                          [13017, 7975, 3084, 2033, 3407], '')
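
For a quick manual check of the updated GLM-large-ch expectations, the snippet below mirrors the assertions in the test hunk above. It is only a sketch: it assumes the GLM-large-ch vocabulary files can be downloaded in the current environment, and the expected ids are copied from the assertions above rather than re-derived.

# Manual smoke check mirroring test_tokenizer_GLM_large_ch above.
# Assumes the GLM-large-ch vocabulary can be fetched; expected ids are
# copied from the test assertions, not independently derived.
from flagai.data.tokenizer import Tokenizer

tokenizer = Tokenizer.from_pretrained("GLM-large-ch")

# Sentence ids framed by sop (50006) and sep (50001), unchanged by this patch.
print(tokenizer.encode_plus("今天吃饭吃了肯德基")["input_ids"])
# expected: [50006, 3378, 1567, 2613, 20282, 50001]

# Command-token layout after this patch: gMASK is now expected at id 50007
# (previously 50009), per the updated SpecialTokens assertion.
for name, tok in sorted(tokenizer.command_name_map.items(), key=lambda kv: kv[0]):
    print(name, tok.token, tok.Id)
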