16 changes: 16 additions & 0 deletions experiments/QNLI/roberta/params.yaml
@@ -0,0 +1,16 @@
dev_file: data/preprocessed/QNLI/dev_roberta.jsonl
meta_dir: data/preprocessed/QNLI/
train_file: data/preprocessed/QNLI/train_roberta.jsonl
#dev_file: data/preprocessed/QQPdebug/dev_roberta.jsonl
#meta_dir: data/preprocessed/QQPdebug/
#train_file: data/preprocessed/QQPdebug/train_roberta.jsonl
network: roberta
fix_embeddings: false
use_cuda: true
batch_size: 32
epoches: 10
optimizer: bert-adam
length_limit: 128
learning_rate: 1.0e-5
warmup_proportion: 0.1
model_dir: ../roberta-base-py/
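As a quick sanity check, the file above is plain YAML, so it can be inspected with a short script (a minimal sketch assuming PyYAML; the actual lion/ training entry point may load it differently):

import yaml

# Path taken from the diff above; adjust if the experiment lives elsewhere.
with open('experiments/QNLI/roberta/params.yaml') as f:
    params = yaml.safe_load(f)

print(params['network'])        # 'roberta'
print(params['batch_size'])     # 32
print(params['learning_rate'])  # 1e-05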
239 changes: 225 additions & 14 deletions lion/common/tokenizer.py
@@ -3,10 +3,14 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import os
import re
import six
import sys
import copy
import json
import spacy
import logging
import regex as re
import unicodedata
import collections
from shutil import copyfile
@@ -532,10 +536,7 @@ class XLNetTokenizer(Tokenizer):

- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
max_model_input_sizes = {}
vocab_files_names = {}

def __init__(self, vocab_file, max_len=None, do_lower_case=False, remove_space=True, keep_accents=False, **kwargs):
def __init__(self, vocab_file, max_len=None, do_lower_case=False, remove_space=True, keep_accents=False):
super(XLNetTokenizer, self).__init__()
self.max_len = max_len if max_len is not None else int(1e12)

@@ -683,18 +684,10 @@ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
return self.added_tokens_decoder[ids]
else:
return self._convert_id_to_token(ids)
return self._convert_id_to_token(ids)
tokens = []
for index in ids:
if skip_special_tokens and index in self.all_special_ids:
continue
if index in self.added_tokens_decoder:
tokens.append(self.added_tokens_decoder[index])
else:
tokens.append(self._convert_id_to_token(index))
tokens.append(self._convert_id_to_token(index))
return tokens

def _convert_id_to_token(self, index, return_unicode=True):
@@ -724,7 +717,223 @@ def save_vocabulary(self, save_directory):
return (out_vocab_file,)


class RobertaTokenizer(Tokenizer):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level Byte-Pair-Encoding
- Requires a space to start the input string => the encoding and tokenize methods should be called with the
``add_prefix_space`` flag set to ``True``.
Otherwise, this tokenizer's ``encode``, ``decode``, and ``tokenize`` methods will not preserve
the space at the beginning of a string: ``tokenizer.decode(tokenizer.encode(" Hello")) == "Hello"``
"""
def __init__(self, vocab_file, merges_file, max_len=None):
super(RobertaTokenizer, self).__init__()
self.unk_token = "<unk>"
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v: k for k, v in self.encoder.items()}
self.byte_encoder = self.bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

self.cache = {}

# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
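# Illustration (not part of the original code): the pattern above splits text into
# contractions, letter runs, digit runs, punctuation runs and whitespace, keeping at
# most one leading space on each chunk, e.g.
#   re.findall(self.pat, "Hello world, it's 2020!")
#   -> ['Hello', ' world', ',', ' it', "'s", ' 2020', '!']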

@property
def vocab_size(self):
return len(self.encoder)

def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = self.get_pairs(word)

if not pairs:
return token

while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break

if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = self.get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
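# Worked example of the loop above with a small, purely hypothetical merge table:
#   bpe_ranks = {('h', 'e'): 0, ('l', 'l'): 1, ('ll', 'o'): 2}
#   bpe('hello'): ('h','e','l','l','o') -> ('he','l','l','o') -> ('he','ll','o') -> ('he','llo')
#   no remaining pair is ranked, so the cached result is 'he llo'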

def tokenize(self, text, add_prefix_space=False):
split_tokens = []
bpe_tokens = self._tokenize(text, add_prefix_space)
bpe_token_ids = self.convert_tokens_to_ids(bpe_tokens)
for bpe_token in bpe_token_ids:
split_tokens.append((
bpe_token,
None,
None,
None,
None,
None,
))
# return split_tokens
# Set special option for non-entity tag: '' vs 'O' in spaCy
return Tokens(split_tokens, opts={'non_ent': ''})

def _tokenize(self, text, add_prefix_space=True):
""" Tokenize a string.
return_unicode is used only for py2
"""
if add_prefix_space:
text = ' ' + text

bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
# Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
# Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))

return bpe_tokens

def convert_tokens_to_ids(self, tokens):
""" Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
(resp. a sequence of ids), using the vocabulary.
"""
if tokens is None:
return None

if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
return self._convert_token_to_id(tokens)

ids = []
for token in tokens:
ids.append(self._convert_token_to_id(token))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this XLNET model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
)
return ids

def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))

def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
""" Converts a single index or a sequence of indices (integers) in a token "
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
Args:
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
"""
if isinstance(ids, int):
return self._convert_id_to_token(ids)
tokens = []
for index in ids:
tokens.append(self._convert_id_to_token(index))
return tokens

def _convert_id_to_token(self, index, return_unicode=True):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
return self.decoder.get(index)

def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
text = ''.join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
return text
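# Example (illustrative): ''.join(['Ġhello', 'Ġworld']) byte-decodes to ' hello world',
# since 'Ġ' is the byte-level stand-in for a leading space (see bytes_to_unicode below).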

def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, 'vocab.json')
merge_file = os.path.join(save_directory, 'merges.txt')

with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))

index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1

return vocab_file, merge_file

def bytes_to_unicode(self):
"""
Returns a dict mapping utf-8 bytes to unicode strings.
We specifically avoid mapping to whitespace/control characters the bpe code barfs on.

The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(
range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
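# In the mapping built above, printable bytes map to themselves (b'A' -> 'A'), while the
# remaining bytes are shifted past 255: a space (0x20) becomes 'Ġ' (U+0120) and a newline
# (0x0A) becomes 'Ċ' (U+010A). That is why word-initial tokens in the GPT-2/RoBERTa vocab
# start with 'Ġ'.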

def get_pairs(self, word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs


def get_class(name):
name = str(name).lower()
if name == 'spacy':
return SpacyTokenizer
elif name == 'bert':
@@ -733,5 +942,7 @@ def get_class(name):
return JiebaTokenizer
elif name == 'xlnet':
return XLNetTokenizer
elif name == 'roberta':
return RobertaTokenizer
else:
raise ValueError("Unspport tokenize algorithm:{}".format(name))
2 changes: 1 addition & 1 deletion lion/data/dataset.py
@@ -63,7 +63,7 @@ def make_char(char_dict, token, word_length=16):
# Not index words
oriAtoken = ex['Atokens']
oriBtoken = ex['Btokens']
if self.args.network == 'xlnet':
if self.args.network == 'xlnet' or self.args.network == 'roberta':
Atoken = torch.LongTensor(ex['Atokens'])
Btoken = torch.LongTensor(ex['Btokens'])
Achar = torch.zeros(len(ex['Atokens']), 16)