From e4a842557835d198960ecbe26697a1d0d3b4f67b Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 13 Aug 2020 17:34:19 +0200 Subject: [PATCH 1/7] fix tokenizers caching --- src/nlp/utils/py_utils.py | 40 ++++++++++-- tests/test_caching.py | 126 ++++++++++++++++++++++++++++++++++++++ tests/utils.py | 24 ++++++++ 3 files changed, 185 insertions(+), 5 deletions(-) create mode 100644 tests/test_caching.py diff --git a/src/nlp/utils/py_utils.py b/src/nlp/utils/py_utils.py index 1c0d743633f..2556c4b1f6f 100644 --- a/src/nlp/utils/py_utils.py +++ b/src/nlp/utils/py_utils.py @@ -29,6 +29,7 @@ import dill import numpy as np +import regex # NOTE: When used on an instance method, the cache is shared across all @@ -268,14 +269,29 @@ class Pickler(dill.Pickler): def dump(obj, file): """pickle an object to a file""" - Pickler(file).dump(obj) + Pickler(file, recurse=True).dump(obj) return +@contextlib.contextmanager +def _no_cache_fields(obj): + try: + import transformers as tr + + if isinstance(obj, (tr.CTRLTokenizer, tr.GPT2Tokenizer, tr.OpenAIGPTTokenizer, tr.XLMTokenizer)): + with temporary_assignment(obj, "cache", {}): + yield + else: + yield + except ImportError: + yield + + def dumps(obj): """pickle an object to a string""" file = StringIO() - dump(obj, file) + with _no_cache_fields(obj): + dump(obj, file) return file.getvalue() @@ -288,7 +304,7 @@ def proxy(func): @pklregister(CodeType) -def save_code(pickler, obj): +def _save_code(pickler, obj): """ From dill._dill.save_code This is a modified version that removes the origin (filename + line no.) 
@@ -297,9 +313,11 @@ def save_code(pickler, obj): dill._dill.log.info("Co: %s" % obj) # Filenames of functions created in notebooks or shells start with '<' # ex: for ipython, and for shell + # Moreover lambda functions have a special name: '' + # ex: (lambda x: x).__code__.co_name == "" # True # Only those two lines are different from the original implementation: - co_filename = "" if obj.co_filename.startswith("<") else obj.co_filename - co_firstlineno = 1 if obj.co_filename.startswith("<") else obj.co_firstlineno + co_filename = "" if obj.co_filename.startswith("<") or obj.co_name == "" else obj.co_filename + co_firstlineno = 1 if obj.co_filename.startswith("<") or obj.co_name == "" else obj.co_firstlineno # The rest is the same as in the original dill implementation if dill._dill.PY3: if hasattr(obj, "co_posonlyargcount"): @@ -363,3 +381,15 @@ def save_code(pickler, obj): def copyfunc(func): return types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__) + + +@pklregister(type(regex.Regex("", 0))) +def _save_regex(pickler, obj): + dill._dill.log.info("Re: %s" % obj) + args = ( + obj.pattern, + obj.flags, + ) + pickler.save_reduce(regex.compile, args, obj=obj) + dill._dill.log.info("# Re") + return diff --git a/tests/test_caching.py b/tests/test_caching.py new file mode 100644 index 00000000000..3863f5453c1 --- /dev/null +++ b/tests/test_caching.py @@ -0,0 +1,126 @@ +from hashlib import md5 +from types import CodeType, FunctionType +from unittest import TestCase + +import regex + +import nlp + +from .utils import require_transformers + + +class Foo: + def __init__(self, foo): + self.foo = foo + + def __call__(self): + return self.foo + + +class TokenizersCachingTest(TestCase): + @require_transformers + def test_hash_tokenizer(self): + from transformers import AutoTokenizer + + def encode(x): + return tokenizer(x) + + # TODO: add hash consistency tests across sessions + tokenizer = 
AutoTokenizer.from_pretrained("bert-base-uncased") + hash1 = md5(nlp.utils.dumps(tokenizer)).hexdigest() + hash1_lambda = md5(nlp.utils.dumps(lambda x: tokenizer(x))).hexdigest() + hash1_encode = md5(nlp.utils.dumps(encode)).hexdigest() + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + hash2 = md5(nlp.utils.dumps(tokenizer)).hexdigest() + hash2_lambda = md5(nlp.utils.dumps(lambda x: tokenizer(x))).hexdigest() + hash2_encode = md5(nlp.utils.dumps(encode)).hexdigest() + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + hash3 = md5(nlp.utils.dumps(tokenizer)).hexdigest() + hash3_lambda = md5(nlp.utils.dumps(lambda x: tokenizer(x))).hexdigest() + hash3_encode = md5(nlp.utils.dumps(encode)).hexdigest() + self.assertEqual(hash1, hash3) + self.assertNotEqual(hash1, hash2) + self.assertEqual(hash1_lambda, hash3_lambda) + self.assertNotEqual(hash1_lambda, hash2_lambda) + self.assertEqual(hash1_encode, hash3_encode) + self.assertNotEqual(hash1_encode, hash2_encode) + + @require_transformers + def test_hash_tokenizer_with_cache(self): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + hash1 = md5(nlp.utils.dumps(tokenizer)).hexdigest() + tokenizer("Hello world !") # call once to change the tokenizer's cache + hash2 = md5(nlp.utils.dumps(tokenizer)).hexdigest() + self.assertEqual(hash1, hash2) + + def test_hash_regex(self): + pat = regex.Regex("foo") + hash1 = md5(nlp.utils.dumps(pat)).hexdigest() + pat = regex.Regex("bar") + hash2 = md5(nlp.utils.dumps(pat)).hexdigest() + pat = regex.Regex("foo") + hash3 = md5(nlp.utils.dumps(pat)).hexdigest() + self.assertEqual(hash1, hash3) + self.assertNotEqual(hash1, hash2) + + +class RecurseDumpTest(TestCase): + def test_recurse_dump_for_function(self): + def func(): + return foo + + foo = [0] + hash1 = md5(nlp.utils.dumps(func)).hexdigest() + foo = [1] + hash2 = md5(nlp.utils.dumps(func)).hexdigest() + foo = [0] + hash3 = md5(nlp.utils.dumps(func)).hexdigest() + 
self.assertEqual(hash1, hash3) + self.assertNotEqual(hash1, hash2) + + def test_recurse_dump_for_class(self): + + hash1 = md5(nlp.utils.dumps(Foo([0]))).hexdigest() + hash2 = md5(nlp.utils.dumps(Foo([1]))).hexdigest() + hash3 = md5(nlp.utils.dumps(Foo([0]))).hexdigest() + self.assertEqual(hash1, hash3) + self.assertNotEqual(hash1, hash2) + + def test_dump_ipython_function(self): + + code_args = ( + "co_argcount", + "co_kwonlyargcount", + "co_nlocals", + "co_stacksize", + "co_flags", + "co_code", + "co_consts", + "co_names", + "co_varnames", + "co_filename", + "co_name", + "co_firstlineno", + "co_lnotab", + "co_freevars", + "co_cellvars", + ) + + def create_ipython_func(co_filename, returned_obj): + def func(): + return returned_obj + + code = func.__code__ + code = CodeType(*[getattr(code, k) if k != "co_filename" else co_filename for k in code_args]) + return FunctionType(code, func.__globals__, func.__name__, func.__defaults__, func.__closure__) + + co_filename, returned_obj = "", [0] + hash1 = md5(nlp.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + co_filename, returned_obj = "", [1] + hash2 = md5(nlp.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + co_filename, returned_obj = "", [0] + hash3 = md5(nlp.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + self.assertEqual(hash1, hash3) + self.assertNotEqual(hash1, hash2) diff --git a/tests/utils.py b/tests/utils.py index 0bc8dc7025a..d0c11ebe5cb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,4 @@ +import logging import os import unittest from distutils.util import strtobool @@ -5,6 +6,17 @@ from nlp.utils.file_utils import _tf_available, _torch_available +logger = logging.getLogger(__name__) + +try: + import transformers + + _transformers_available = True # pylint: disable=invalid-name + logger.info("transformers version {} available.".format(transformers.__version__)) +except ImportError: + _transformers_available = False # 
pylint: disable=invalid-name + + def parse_flag_from_env(key, default=False): try: value = os.environ[key] @@ -50,6 +62,18 @@ def require_tf(test_case): return test_case +def require_transformers(test_case): + """ + Decorator marking a test that requires transformers. + + These tests are skipped when transformers isn't installed. + + """ + if not _transformers_available: + test_case = unittest.skip("test requires transformers")(test_case) + return test_case + + def slow(test_case): """ Decorator marking a test as slow. From a11a1a54e87a3b0dbb8e74893e33038c6af5ba11 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 13 Aug 2020 17:44:02 +0200 Subject: [PATCH 2/7] register regex saver only if regex is available --- src/nlp/utils/py_utils.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/nlp/utils/py_utils.py b/src/nlp/utils/py_utils.py index 2556c4b1f6f..952f462dde1 100644 --- a/src/nlp/utils/py_utils.py +++ b/src/nlp/utils/py_utils.py @@ -29,7 +29,6 @@ import dill import numpy as np -import regex # NOTE: When used on an instance method, the cache is shared across all @@ -383,13 +382,18 @@ def copyfunc(func): return types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__) -@pklregister(type(regex.Regex("", 0))) -def _save_regex(pickler, obj): - dill._dill.log.info("Re: %s" % obj) - args = ( - obj.pattern, - obj.flags, - ) - pickler.save_reduce(regex.compile, args, obj=obj) - dill._dill.log.info("# Re") - return +try: + import regex + + @pklregister(type(regex.Regex("", 0))) + def _save_regex(pickler, obj): + dill._dill.log.info("Re: %s" % obj) + args = ( + obj.pattern, + obj.flags, + ) + pickler.save_reduce(regex.compile, args, obj=obj) + dill._dill.log.info("# Re") + return +except ImportError: + pass From 99e9d24d23c938baa06edd5719246c39b80481c8 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 13 Aug 2020 17:44:41 +0200 Subject: [PATCH 3/7] style --- 
src/nlp/utils/py_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/nlp/utils/py_utils.py b/src/nlp/utils/py_utils.py index 952f462dde1..de1e58f5322 100644 --- a/src/nlp/utils/py_utils.py +++ b/src/nlp/utils/py_utils.py @@ -395,5 +395,7 @@ def _save_regex(pickler, obj): pickler.save_reduce(regex.compile, args, obj=obj) dill._dill.log.info("# Re") return + + except ImportError: pass From a4513d7a0e8162b2a1639e30a41bb9190e486bdc Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 13 Aug 2020 17:58:16 +0200 Subject: [PATCH 4/7] add _transformers_available in utils --- src/nlp/utils/file_utils.py | 7 +++++++ src/nlp/utils/py_utils.py | 6 ++++-- tests/utils.py | 13 +------------ 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/nlp/utils/file_utils.py b/src/nlp/utils/file_utils.py index bb54d59d73f..62746b136bd 100644 --- a/src/nlp/utils/file_utils.py +++ b/src/nlp/utils/file_utils.py @@ -59,6 +59,13 @@ except (ImportError, AssertionError): _tf_available = False # pylint: disable=invalid-name +try: + import transformers + + _transformers_available = True # pylint: disable=invalid-name + logger.info("transformers version {} available.".format(transformers.__version__)) +except ImportError: + _transformers_available = False # pylint: disable=invalid-name hf_cache_home = os.path.expanduser( os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) diff --git a/src/nlp/utils/py_utils.py b/src/nlp/utils/py_utils.py index de1e58f5322..569cfe78d06 100644 --- a/src/nlp/utils/py_utils.py +++ b/src/nlp/utils/py_utils.py @@ -30,6 +30,8 @@ import dill import numpy as np +from .file_utils import _transformers_available + # NOTE: When used on an instance method, the cache is shared across all # instances and IS NOT per-instance. 
@@ -274,7 +276,7 @@ def dump(obj, file): @contextlib.contextmanager def _no_cache_fields(obj): - try: + if _transformers_available: import transformers as tr if isinstance(obj, (tr.CTRLTokenizer, tr.GPT2Tokenizer, tr.OpenAIGPTTokenizer, tr.XLMTokenizer)): @@ -282,7 +284,7 @@ def _no_cache_fields(obj): yield else: - except ImportError: + else: yield diff --git a/tests/utils.py b/tests/utils.py index d0c11ebe5cb..f07f615e4c4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,18 +3,7 @@ import unittest from distutils.util import strtobool -from nlp.utils.file_utils import _tf_available, _torch_available - - -logger = logging.getLogger(__name__) - -try: - import transformers - - _transformers_available = True # pylint: disable=invalid-name - logger.info("transformers version {} available.".format(transformers.__version__)) -except ImportError: - _transformers_available = False # pylint: disable=invalid-name +from nlp.utils.file_utils import _tf_available, _torch_available, _transformers_available def parse_flag_from_env(key, default=False): From 899dad80d080aa09ff2d2d6cf75e511b355eefaf Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 13 Aug 2020 17:59:45 +0200 Subject: [PATCH 5/7] quality --- tests/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index f07f615e4c4..338dd2e25f0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,3 @@ -import logging import os import unittest from distutils.util import strtobool From 71187ccf12667f50d9d0bf98618aa3b95fcf060e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 19 Aug 2020 12:15:45 +0200 Subject: [PATCH 6/7] add test for methods --- tests/test_caching.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_caching.py b/tests/test_caching.py index 3863f5453c1..3e9bb367064 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -88,6 +88,14 @@ def test_recurse_dump_for_class(self): self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2) + def test_recurse_dump_for_method(self): + + hash1 = md5(nlp.utils.dumps(Foo([0]).__call__)).hexdigest() + hash2 = md5(nlp.utils.dumps(Foo([1]).__call__)).hexdigest() + hash3 = md5(nlp.utils.dumps(Foo([0]).__call__)).hexdigest() + self.assertEqual(hash1, hash3) + self.assertNotEqual(hash1, hash2) + def test_dump_ipython_function(self): code_args = ( From a7bed869f08e0df2114849e3d17b5372e7a28bea Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 19 Aug 2020 12:16:09 +0200 Subject: [PATCH 7/7] test for tokenizers' cache instead of having a list of tokenizers classes with cache --- src/nlp/utils/py_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nlp/utils/py_utils.py b/src/nlp/utils/py_utils.py index 569cfe78d06..eb439d8745f 100644 --- a/src/nlp/utils/py_utils.py +++ b/src/nlp/utils/py_utils.py @@ -279,7 +279,7 @@ def _no_cache_fields(obj): if _transformers_available: import transformers as tr - if isinstance(obj, (tr.CTRLTokenizer, tr.GPT2Tokenizer, tr.OpenAIGPTTokenizer, tr.XLMTokenizer)): + if isinstance(obj, tr.PreTrainedTokenizerBase) and hasattr(obj, "cache") and isinstance(obj.cache, dict): with temporary_assignment(obj, "cache", {}): yield else: