diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index f775b42e070..f4cf169a4ab 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -316,11 +316,8 @@ def dump(obj, file): @contextlib.contextmanager def _no_cache_fields(obj): try: - import transformers as tr - if ( - hasattr(tr, "PreTrainedTokenizerBase") - and isinstance(obj, tr.PreTrainedTokenizerBase) + "PreTrainedTokenizerBase" in [base_class.__name__ for base_class in type(obj).__mro__] and hasattr(obj, "cache") and isinstance(obj.cache, dict) ):