diff --git a/setup.py b/setup.py index 60e0b15a330..a4ab4df104a 100644 --- a/setup.py +++ b/setup.py @@ -180,7 +180,7 @@ "tensorflow>=2.3,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'", "tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'", "tiktoken", - "torch", + "torch>=2.0.0", "soundfile>=0.12.1", "transformers", "typing-extensions>=4.6.1", # due to conflict between apache-beam and pydantic diff --git a/src/datasets/fingerprint.py b/src/datasets/fingerprint.py index 7d73758a049..7d2f2bdb98a 100644 --- a/src/datasets/fingerprint.py +++ b/src/datasets/fingerprint.py @@ -1,5 +1,4 @@ import inspect -import json import os import random import shutil @@ -10,15 +9,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import pyarrow as pa import xxhash -from .info import DatasetInfo from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH -from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table +from .utils._dill import dumps from .utils.deprecation_utils import deprecated from .utils.logging import get_logger -from .utils.py_utils import asdict, dumps if TYPE_CHECKING: @@ -199,6 +195,7 @@ def cleanup_func(): ################# +@deprecated("Use `copyreg.pickle` to register a custom reducer.") def hashregister(*types): def proxy(func): for t in types: @@ -225,15 +222,13 @@ def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str: return m.hexdigest() @classmethod + @deprecated("Use `Hasher.hash` instead.") def hash_default(cls, value: Any) -> str: - return cls.hash_bytes(dumps(value)) + return cls.hash(value) @classmethod def hash(cls, value: Any) -> str: - if type(value) in cls.dispatch: - return cls.dispatch[type(value)](cls, value) - else: - return cls.hash_default(value) + return cls.hash_bytes(dumps(value)) def update(self, value: Any) -> None: header_for_update = f"=={type(value)}==" @@ -245,28 +240,6 @@ def hexdigest(self) -> str: return self.m.hexdigest() -# Register a new hasher can be useful for two possible reasons: -# 1 - optimize the hashing of large amount of data (e.g. pa.Table) -# 2 - take advantage of a custom serialization method (e.g. DatasetInfo) - - -@hashregister(pa.Table, Table, InMemoryTable, MemoryMappedTable, ConcatenationTable) -def _hash_pa_table(hasher, value): - def _hash_pa_array(value): - if isinstance(value, pa.ChunkedArray): - return hasher.hash_bytes(c.to_string().encode("utf-8") for c in value.chunks) - else: - return hasher.hash_bytes(value.to_string().encode("utf-8")) - - value = "-".join(col + "-" + _hash_pa_array(value[col]) for col in sorted(value.column_names)) - return hasher.hash_bytes(value.encode("utf-8")) - - -@hashregister(DatasetInfo) -def _hash_dataset_info(hasher, value): - return hasher.hash_bytes(json.dumps(asdict(value), sort_keys=True).encode("utf-8")) - - ################# # Fingerprinting ################# diff --git a/src/datasets/utils/_dill.py b/src/datasets/utils/_dill.py new file mode 100644 index 00000000000..237f2d8eab2 --- /dev/null +++ b/src/datasets/utils/_dill.py @@ -0,0 +1,430 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Extends `dill` to support pickling more types and produce more consistent dumps.""" +import os +import sys +from io import BytesIO +from types import CodeType, FunctionType + +import dill +from packaging import version + +from .. import config + + +class Pickler(dill.Pickler): + dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy()) + + def save(self, obj, save_persistent_id=True): + obj_type = type(obj) + if obj_type not in self.dispatch: + if "regex" in sys.modules: + import regex # type: ignore + + if obj_type is regex.Pattern: + pklregister(obj_type)(_save_regexPattern) + if "spacy" in sys.modules: + import spacy # type: ignore + + if issubclass(obj_type, spacy.Language): + pklregister(obj_type)(_save_spacyLanguage) + if "tiktoken" in sys.modules: + import tiktoken # type: ignore + + if obj_type is tiktoken.Encoding: + pklregister(obj_type)(_save_tiktokenEncoding) + if "torch" in sys.modules: + import torch # type: ignore + + if issubclass(obj_type, torch.Tensor): + pklregister(obj_type)(_save_torchTensor) + + # Unwrap `torch.compile`-ed modules + if issubclass(obj_type, torch.nn.Module): + obj = getattr(obj, "_orig_mod", obj) + if "transformers" in sys.modules: + import transformers # type: ignore + + if issubclass(obj_type, transformers.PreTrainedTokenizerBase): + pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase) + + # Unwrap `torch.compile`-ed functions + if obj_type is FunctionType: + obj = getattr(obj, "_torchdynamo_orig_callable", obj) + dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id) + + def _batch_setitems(self, items): + # Ignore the order of keys in a dict + try: + # Faster, but fails for unorderable elements + items = sorted(items) + except Exception: # TypeError, decimal.InvalidOperation, etc. + from datasets.fingerprint import Hasher + + items = sorted(items, key=lambda x: Hasher.hash(x[0])) + dill.Pickler._batch_setitems(self, items) + + def memoize(self, obj): + # Don't memoize strings since two identical strings can have different Python ids + if type(obj) is not str: # noqa: E721 + dill.Pickler.memoize(self, obj) + + +def pklregister(t): + """Register a custom reducer for the type.""" + + def proxy(func): + Pickler.dispatch[t] = func + return func + + return proxy + + +def dump(obj, file): + """Pickle an object to a file.""" + Pickler(file, recurse=True).dump(obj) + + +def dumps(obj): + """Pickle an object to a string.""" + file = BytesIO() + dump(obj, file) + return file.getvalue() + + +if config.DILL_VERSION < version.parse("0.3.6"): + + def log(pickler, msg): + dill._dill.log.info(msg) + +elif config.DILL_VERSION.release[:3] in [version.parse("0.3.6").release, version.parse("0.3.7").release]: + + def log(pickler, msg): + dill._dill.logger.trace(pickler, msg) + + +@pklregister(set) +def _save_set(pickler, obj): + log(pickler, f"Se: {obj}") + try: + # Faster, but fails for unorderable elements + args = (sorted(obj),) + except Exception: # TypeError, decimal.InvalidOperation, etc. + from datasets.fingerprint import Hasher + + args = (sorted(obj, key=Hasher.hash),) + + pickler.save_reduce(set, args, obj=obj) + log(pickler, "# Se") + + +def _save_regexPattern(pickler, obj): + import regex # type: ignore + + log(pickler, f"Re: {obj}") + args = (obj.pattern, obj.flags) + pickler.save_reduce(regex.compile, args, obj=obj) + log(pickler, "# Re") + + +def _save_tiktokenEncoding(pickler, obj): + import tiktoken # type: ignore + + log(pickler, f"Enc: {obj}") + args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens) + pickler.save_reduce(tiktoken.Encoding, args, obj=obj) + log(pickler, "# Enc") + + +def _save_torchTensor(pickler, obj): + import torch # type: ignore + + # `torch.from_numpy` is not picklable in `torch>=1.11.0` + def create_torchTensor(np_array): + return torch.from_numpy(np_array) + + log(pickler, f"To: {obj}") + args = (obj.detach().cpu().numpy(),) + pickler.save_reduce(create_torchTensor, args, obj=obj) + log(pickler, "# To") + + +def _save_spacyLanguage(pickler, obj): + import spacy # type: ignore + + def create_spacyLanguage(config, bytes): + lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"]) + lang_inst = lang_cls.from_config(config) + return lang_inst.from_bytes(bytes) + + log(pickler, f"Sp: {obj}") + args = (obj.config, obj.to_bytes()) + pickler.save_reduce(create_spacyLanguage, args, obj=obj) + log(pickler, "# Sp") + + +def _save_transformersPreTrainedTokenizerBase(pickler, obj): + log(pickler, f"Tok: {obj}") + # Ignore the `cache` attribute + state = obj.__dict__ + if "cache" in state and isinstance(state["cache"], dict): + state["cache"] = {} + pickler.save_reduce(type(obj), (), state=state, obj=obj) + log(pickler, "# Tok") + + +if config.DILL_VERSION < version.parse("0.3.6"): + + @pklregister(CodeType) + def _save_code(pickler, obj): + """ + From dill._dill.save_code + This is a modified version that removes the origin (filename + line no.) + of functions created in notebooks or shells for example. + """ + dill._dill.log.info(f"Co: {obj}") + # The filename of a function is the .py file where it is defined. + # Filenames of functions created in notebooks or shells start with '<' + # ex: for ipython, and for shell + # Filenames of functions created in ipykernel the filename + # look like f"{tempdir}/ipykernel_{id1}/{id2}.py" + # Moreover lambda functions have a special name: '' + # ex: (lambda x: x).__code__.co_name == "" # True + # + # For the hashing mechanism we ignore where the function has been defined + # More specifically: + # - we ignore the filename of special functions (filename starts with '<') + # - we always ignore the line number + # - we only use the base name of the file instead of the whole path, + # to be robust in case a script is moved for example. + # + # Only those two lines are different from the original implementation: + co_filename = ( + "" + if obj.co_filename.startswith("<") + or ( + len(obj.co_filename.split(os.path.sep)) > 1 + and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_") + ) + or obj.co_name == "" + else os.path.basename(obj.co_filename) + ) + co_firstlineno = 1 + # The rest is the same as in the original dill implementation + if dill._dill.PY3: + if hasattr(obj, "co_posonlyargcount"): + args = ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, + obj.co_name, + co_firstlineno, + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + else: + args = ( + obj.co_argcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, + obj.co_name, + co_firstlineno, + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + else: + args = ( + obj.co_argcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, + obj.co_name, + co_firstlineno, + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + pickler.save_reduce(CodeType, args, obj=obj) + dill._dill.log.info("# Co") + return + +elif config.DILL_VERSION.release[:3] in [version.parse("0.3.6").release, version.parse("0.3.7").release]: + # From: https://github.com/uqfoundation/dill/blob/dill-0.3.6/dill/_dill.py#L1104 + @pklregister(CodeType) + def save_code(pickler, obj): + dill._dill.logger.trace(pickler, "Co: %s", obj) + + ############################################################################################################ + # Modification here for huggingface/datasets + # The filename of a function is the .py file where it is defined. + # Filenames of functions created in notebooks or shells start with '<' + # ex: for ipython, and for shell + # Filenames of functions created in ipykernel the filename + # look like f"{tempdir}/ipykernel_{id1}/{id2}.py" + # Moreover lambda functions have a special name: '' + # ex: (lambda x: x).__code__.co_name == "" # True + # + # For the hashing mechanism we ignore where the function has been defined + # More specifically: + # - we ignore the filename of special functions (filename starts with '<') + # - we always ignore the line number + # - we only use the base name of the file instead of the whole path, + # to be robust in case a script is moved for example. + # + # Only those two lines are different from the original implementation: + co_filename = ( + "" + if obj.co_filename.startswith("<") + or ( + len(obj.co_filename.split(os.path.sep)) > 1 + and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_") + ) + or obj.co_name == "" + else os.path.basename(obj.co_filename) + ) + co_firstlineno = 1 + # The rest is the same as in the original dill implementation, except for the replacements: + # - obj.co_filename => co_filename + # - obj.co_firstlineno => co_firstlineno + ############################################################################################################ + + if hasattr(obj, "co_endlinetable"): # python 3.11a (20 args) + args = ( + obj.co_lnotab, # for < python 3.10 [not counted in args] + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + obj.co_qualname, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_linetable, + obj.co_endlinetable, + obj.co_columntable, + obj.co_exceptiontable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_exceptiontable"): # python 3.11 (18 args) + args = ( + obj.co_lnotab, # for < python 3.10 [not counted in args] + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + obj.co_qualname, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_linetable, + obj.co_exceptiontable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_linetable"): # python 3.10 (16 args) + args = ( + obj.co_lnotab, # for < python 3.10 [not counted in args] + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_linetable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_posonlyargcount"): # python 3.8 (16 args) + args = ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + else: # python 3.7 (15 args) + args = ( + obj.co_argcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + + pickler.save_reduce(dill._dill._create_code, args, obj=obj) + dill._dill.logger.trace(pickler, "# Co") + return diff --git a/src/datasets/utils/filelock.py b/src/datasets/utils/filelock.py index 66c3c97649e..df0728efe64 100644 --- a/src/datasets/utils/filelock.py +++ b/src/datasets/utils/filelock.py @@ -1,6 +1,6 @@ # deprecated, please use the `filelock` package instead -from filelock import ( # noqa: F401 # imported for backward compatibility +from filelock import ( # noqa: F401 # imported for backward compatibility TODO: remove in 3.0.0 BaseFileLock, SoftFileLock, Timeout, @@ -8,4 +8,4 @@ WindowsFileLock, ) -from ._filelock import FileLock # noqa: F401 # imported for backward compatibility +from ._filelock import FileLock # noqa: F401 # imported for backward compatibility. TODO: remove in 3.0.0 diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index 4d49c3b5865..b4eef3bce6a 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -28,25 +28,27 @@ import warnings from contextlib import contextmanager from dataclasses import fields, is_dataclass -from io import BytesIO as StringIO from multiprocessing import Manager from queue import Empty from shutil import disk_usage -from types import CodeType, FunctionType from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union from urllib.parse import urlparse -import dill import multiprocess import multiprocess.pool import numpy as np -from packaging import version from tqdm.auto import tqdm from .. import config from ..parallel import parallel_map from . import logging from . import tqdm as hf_tqdm +from ._dill import ( # noqa: F401 # imported for backward compatibility. TODO: remove in 3.0.0 + Pickler, + dump, + dumps, + pklregister, +) try: # pragma: no branch @@ -599,771 +601,6 @@ def get_imports(file_path: str) -> Tuple[str, str, str, str]: return imports -class Pickler(dill.Pickler): - """Same Pickler as the one from dill, but improved for notebooks and shells""" - - dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy()) - - def save(self, obj, save_persistent_id=True): - # lazy registration of reduction functions - obj_type = type(obj) - if obj_type not in Pickler.dispatch: - if config.DILL_VERSION < version.parse("0.3.6"): - - def dill_log(pickler, msg): - dill._dill.log.info(msg) - - elif config.DILL_VERSION.release[:3] in [version.parse("0.3.6").release, version.parse("0.3.7").release]: - - def dill_log(pickler, msg): - dill._dill.logger.trace(pickler, msg) - - if (obj_type.__module__, obj_type.__name__) == ("_regex", "Pattern"): - try: - import regex - - @pklregister(obj_type) - def _save_regex(pickler, obj): - dill_log(pickler, f"Re: {obj}") - args = ( - obj.pattern, - obj.flags, - ) - pickler.save_reduce(regex.compile, args, obj=obj) - dill_log(pickler, "# Re") - return - - except ImportError: - pass - elif (obj_type.__module__, obj_type.__name__) == ("torch", "Tensor"): - try: - import torch - - @pklregister(obj_type) - def _save_tensor(pickler, obj): - # `torch.from_numpy` is not picklable in `torch>=1.11.0` - def _create_tensor(np_array): - return torch.from_numpy(np_array) - - dill_log(pickler, f"To: {obj}") - args = (obj.detach().cpu().numpy(),) - pickler.save_reduce(_create_tensor, args, obj=obj) - dill_log(pickler, "# To") - return - - except ImportError: - pass - elif (obj_type.__module__, obj_type.__name__) == ("tiktoken.core", "Encoding"): - try: - import tiktoken - - @pklregister(obj_type) - def _save_encoding(pickler, obj): - dill_log(pickler, f"Enc: {obj}") - args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens) - pickler.save_reduce(tiktoken.Encoding, args, obj=obj) - dill_log(pickler, "# Enc") - return - - except ImportError: - pass - elif obj_type.__module__.startswith("spacy.lang") and any( - (cls.__module__, cls.__name__) == ("spacy.language", "Language") for cls in obj_type.__mro__ - ): - try: - import spacy - - @pklregister(obj_type) - def _save_lang(pickler, obj): - def _create_lang(config, bytes_data): - lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"]) - nlp = lang_cls.from_config(config) - return nlp.from_bytes(bytes_data) - - dill_log(pickler, f"Sp: {obj}") - args = (obj.config, obj.to_bytes()) - pickler.save_reduce(_create_lang, args, obj=obj) - dill_log(pickler, "# Sp") - return - - except ImportError: - pass - - dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id) - - def memoize(self, obj): - # don't memoize strings since two identical strings can have different python ids - if type(obj) != str: # noqa: E721 - dill.Pickler.memoize(self, obj) - - -def dump(obj, file): - """pickle an object to a file""" - Pickler(file, recurse=True).dump(obj) - return - - -@contextmanager -def _no_cache_fields(obj): - try: - if ( - "PreTrainedTokenizerBase" in [base_class.__name__ for base_class in type(obj).__mro__] - and hasattr(obj, "cache") - and isinstance(obj.cache, dict) - ): - with temporary_assignment(obj, "cache", {}): - yield - else: - yield - - except ImportError: - yield - - -def dumps(obj): - """pickle an object to a string""" - file = StringIO() - with _no_cache_fields(obj): - dump(obj, file) - return file.getvalue() - - -def pklregister(t): - def proxy(func): - Pickler.dispatch[t] = func - return func - - return proxy - - -if config.DILL_VERSION < version.parse("0.3.6"): - - @pklregister(set) - def _save_set(pickler, obj): - dill._dill.log.info(f"Se: {obj}") - from datasets.fingerprint import Hasher - - args = (sorted(obj, key=Hasher.hash),) - pickler.save_reduce(set, args, obj=obj) - dill._dill.log.info("# Se") - -elif config.DILL_VERSION.release[:3] in [version.parse("0.3.6").release, version.parse("0.3.7").release]: - - @pklregister(set) - def _save_set(pickler, obj): - dill._dill.logger.trace(pickler, "Se: %s", obj) - from datasets.fingerprint import Hasher - - args = (sorted(obj, key=Hasher.hash),) - pickler.save_reduce(set, args, obj=obj) - dill._dill.logger.trace(pickler, "# Se") - - -if config.DILL_VERSION < version.parse("0.3.6"): - - @pklregister(CodeType) - def _save_code(pickler, obj): - """ - From dill._dill.save_code - This is a modified version that removes the origin (filename + line no.) - of functions created in notebooks or shells for example. - """ - dill._dill.log.info(f"Co: {obj}") - # The filename of a function is the .py file where it is defined. - # Filenames of functions created in notebooks or shells start with '<' - # ex: for ipython, and for shell - # Filenames of functions created in ipykernel the filename - # look like f"{tempdir}/ipykernel_{id1}/{id2}.py" - # Moreover lambda functions have a special name: '' - # ex: (lambda x: x).__code__.co_name == "" # True - # - # For the hashing mechanism we ignore where the function has been defined - # More specifically: - # - we ignore the filename of special functions (filename starts with '<') - # - we always ignore the line number - # - we only use the base name of the file instead of the whole path, - # to be robust in case a script is moved for example. - # - # Only those two lines are different from the original implementation: - co_filename = ( - "" - if obj.co_filename.startswith("<") - or ( - len(obj.co_filename.split(os.path.sep)) > 1 - and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_") - ) - or obj.co_name == "" - else os.path.basename(obj.co_filename) - ) - co_firstlineno = 1 - # The rest is the same as in the original dill implementation - if dill._dill.PY3: - if hasattr(obj, "co_posonlyargcount"): - args = ( - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - co_filename, - obj.co_name, - co_firstlineno, - obj.co_lnotab, - obj.co_freevars, - obj.co_cellvars, - ) - else: - args = ( - obj.co_argcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - co_filename, - obj.co_name, - co_firstlineno, - obj.co_lnotab, - obj.co_freevars, - obj.co_cellvars, - ) - else: - args = ( - obj.co_argcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - co_filename, - obj.co_name, - co_firstlineno, - obj.co_lnotab, - obj.co_freevars, - obj.co_cellvars, - ) - pickler.save_reduce(CodeType, args, obj=obj) - dill._dill.log.info("# Co") - return - -elif config.DILL_VERSION.release[:3] in [version.parse("0.3.6").release, version.parse("0.3.7").release]: - # From: https://github.com/uqfoundation/dill/blob/dill-0.3.6/dill/_dill.py#L1104 - @pklregister(CodeType) - def save_code(pickler, obj): - dill._dill.logger.trace(pickler, "Co: %s", obj) - - ############################################################################################################ - # Modification here for huggingface/datasets - # The filename of a function is the .py file where it is defined. - # Filenames of functions created in notebooks or shells start with '<' - # ex: for ipython, and for shell - # Filenames of functions created in ipykernel the filename - # look like f"{tempdir}/ipykernel_{id1}/{id2}.py" - # Moreover lambda functions have a special name: '' - # ex: (lambda x: x).__code__.co_name == "" # True - # - # For the hashing mechanism we ignore where the function has been defined - # More specifically: - # - we ignore the filename of special functions (filename starts with '<') - # - we always ignore the line number - # - we only use the base name of the file instead of the whole path, - # to be robust in case a script is moved for example. - # - # Only those two lines are different from the original implementation: - co_filename = ( - "" - if obj.co_filename.startswith("<") - or ( - len(obj.co_filename.split(os.path.sep)) > 1 - and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_") - ) - or obj.co_name == "" - else os.path.basename(obj.co_filename) - ) - co_firstlineno = 1 - # The rest is the same as in the original dill implementation, except for the replacements: - # - obj.co_filename => co_filename - # - obj.co_firstlineno => co_firstlineno - ############################################################################################################ - - if hasattr(obj, "co_endlinetable"): # python 3.11a (20 args) - args = ( - obj.co_lnotab, # for < python 3.10 [not counted in args] - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - co_filename, # Modification for huggingface/datasets ############################################ - obj.co_name, - obj.co_qualname, - co_firstlineno, # Modification for huggingface/datasets ######################################### - obj.co_linetable, - obj.co_endlinetable, - obj.co_columntable, - obj.co_exceptiontable, - obj.co_freevars, - obj.co_cellvars, - ) - elif hasattr(obj, "co_exceptiontable"): # python 3.11 (18 args) - args = ( - obj.co_lnotab, # for < python 3.10 [not counted in args] - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - co_filename, # Modification for huggingface/datasets ############################################ - obj.co_name, - obj.co_qualname, - co_firstlineno, # Modification for huggingface/datasets ######################################### - obj.co_linetable, - obj.co_exceptiontable, - obj.co_freevars, - obj.co_cellvars, - ) - elif hasattr(obj, "co_linetable"): # python 3.10 (16 args) - args = ( - obj.co_lnotab, # for < python 3.10 [not counted in args] - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - co_filename, # Modification for huggingface/datasets ############################################ - obj.co_name, - co_firstlineno, # Modification for huggingface/datasets ######################################### - obj.co_linetable, - obj.co_freevars, - obj.co_cellvars, - ) - elif hasattr(obj, "co_posonlyargcount"): # python 3.8 (16 args) - args = ( - obj.co_argcount, - obj.co_posonlyargcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - co_filename, # Modification for huggingface/datasets ############################################ - obj.co_name, - co_firstlineno, # Modification for huggingface/datasets ######################################### - obj.co_lnotab, - obj.co_freevars, - obj.co_cellvars, - ) - else: # python 3.7 (15 args) - args = ( - obj.co_argcount, - obj.co_kwonlyargcount, - obj.co_nlocals, - obj.co_stacksize, - obj.co_flags, - obj.co_code, - obj.co_consts, - obj.co_names, - obj.co_varnames, - co_filename, # Modification for huggingface/datasets ############################################ - obj.co_name, - co_firstlineno, # Modification for huggingface/datasets ######################################### - obj.co_lnotab, - obj.co_freevars, - obj.co_cellvars, - ) - - pickler.save_reduce(dill._dill._create_code, args, obj=obj) - dill._dill.logger.trace(pickler, "# Co") - return - - -if config.DILL_VERSION < version.parse("0.3.5"): - - @pklregister(FunctionType) - def save_function(pickler, obj): - """ - From dill._dill.save_function - This is a modified version that make globs deterministic since the order of - the keys in the output dictionary of globalvars can change. - """ - if not dill._dill._locate_function(obj): - dill._dill.log.info(f"F1: {obj}") - if getattr(pickler, "_recurse", False): - # recurse to get all globals referred to by obj - globalvars = dill.detect.globalvars - globs = globalvars(obj, recurse=True, builtin=True) - if id(obj) in dill._dill.stack: - globs = obj.__globals__ if dill._dill.PY3 else obj.func_globals - else: - globs = obj.__globals__ if dill._dill.PY3 else obj.func_globals - # globs is a dictionary with keys = var names (str) and values = python objects - # however the dictionary is not always loaded in the same order - # therefore we have to sort the keys to make deterministic. - # This is important to make `dump` deterministic. - # Only this line is different from the original implementation: - globs = dict(sorted(globs.items())) - # The rest is the same as in the original dill implementation - _byref = getattr(pickler, "_byref", None) - _recurse = getattr(pickler, "_recurse", None) - _memo = (id(obj) in dill._dill.stack) and (_recurse is not None) - dill._dill.stack[id(obj)] = len(dill._dill.stack), obj - if dill._dill.PY3: - _super = ("super" in getattr(obj.__code__, "co_names", ())) and (_byref is not None) - if _super: - pickler._byref = True - if _memo: - pickler._recurse = False - fkwdefaults = getattr(obj, "__kwdefaults__", None) - pickler.save_reduce( - dill._dill._create_function, - (obj.__code__, globs, obj.__name__, obj.__defaults__, obj.__closure__, obj.__dict__, fkwdefaults), - obj=obj, - ) - else: - _super = ( - ("super" in getattr(obj.func_code, "co_names", ())) - and (_byref is not None) - and getattr(pickler, "_recurse", False) - ) - if _super: - pickler._byref = True - if _memo: - pickler._recurse = False - pickler.save_reduce( - dill._dill._create_function, - (obj.func_code, globs, obj.func_name, obj.func_defaults, obj.func_closure, obj.__dict__), - obj=obj, - ) - if _super: - pickler._byref = _byref - if _memo: - pickler._recurse = _recurse - if ( - dill._dill.OLDER - and not _byref - and (_super or (not _super and _memo) or (not _super and not _memo and _recurse)) - ): - pickler.clear_memo() - dill._dill.log.info("# F1") - else: - dill._dill.log.info(f"F2: {obj}") - name = getattr(obj, "__qualname__", getattr(obj, "__name__", None)) - dill._dill.StockPickler.save_global(pickler, obj, name=name) - dill._dill.log.info("# F2") - return - -elif config.DILL_VERSION.release[:3] == version.parse("0.3.5").release: # 0.3.5, 0.3.5.1 - # https://github.com/uqfoundation/dill/blob/dill-0.3.5.1/dill/_dill.py - @pklregister(FunctionType) - def save_function(pickler, obj): - if not dill._dill._locate_function(obj, pickler): - dill._dill.log.info("F1: %s" % obj) - _recurse = getattr(pickler, "_recurse", None) - _postproc = getattr(pickler, "_postproc", None) - _main_modified = getattr(pickler, "_main_modified", None) - _original_main = getattr(pickler, "_original_main", dill._dill.__builtin__) # 'None' - postproc_list = [] - if _recurse: - # recurse to get all globals referred to by obj - from dill.detect import globalvars - - globs_copy = globalvars(obj, recurse=True, builtin=True) - - # Add the name of the module to the globs dictionary to prevent - # the duplication of the dictionary. Pickle the unpopulated - # globals dictionary and set the remaining items after the function - # is created to correctly handle recursion. - globs = {"__name__": obj.__module__} - else: - globs_copy = obj.__globals__ if dill._dill.PY3 else obj.func_globals - - # If the globals is the __dict__ from the module being saved as a - # session, substitute it by the dictionary being actually saved. - if _main_modified and globs_copy is _original_main.__dict__: - globs_copy = getattr(pickler, "_main", _original_main).__dict__ - globs = globs_copy - # If the globals is a module __dict__, do not save it in the pickle. - elif ( - globs_copy is not None - and obj.__module__ is not None - and getattr(dill._dill._import_module(obj.__module__, True), "__dict__", None) is globs_copy - ): - globs = globs_copy - else: - globs = {"__name__": obj.__module__} - - # DONE: modified here for huggingface/datasets - # - globs is a dictionary with keys = var names (str) and values = python objects - # - globs_copy is a dictionary with keys = var names (str) and values = ids of the python objects - # however the dictionary is not always loaded in the same order - # therefore we have to sort the keys to make deterministic. - # This is important to make `dump` deterministic. - # Only these line are different from the original implementation: - # START - globs_is_globs_copy = globs is globs_copy - globs = dict(sorted(globs.items())) - if globs_is_globs_copy: - globs_copy = globs - elif globs_copy is not None: - globs_copy = dict(sorted(globs_copy.items())) - # END - - if globs_copy is not None and globs is not globs_copy: - # In the case that the globals are copied, we need to ensure that - # the globals dictionary is updated when all objects in the - # dictionary are already created. - if dill._dill.PY3: - glob_ids = {id(g) for g in globs_copy.values()} - else: - glob_ids = {id(g) for g in globs_copy.itervalues()} - for stack_element in _postproc: - if stack_element in glob_ids: - _postproc[stack_element].append((dill._dill._setitems, (globs, globs_copy))) - break - else: - postproc_list.append((dill._dill._setitems, (globs, globs_copy))) - - if dill._dill.PY3: - closure = obj.__closure__ - state_dict = {} - for fattrname in ("__doc__", "__kwdefaults__", "__annotations__"): - fattr = getattr(obj, fattrname, None) - if fattr is not None: - state_dict[fattrname] = fattr - if obj.__qualname__ != obj.__name__: - state_dict["__qualname__"] = obj.__qualname__ - if "__name__" not in globs or obj.__module__ != globs["__name__"]: - state_dict["__module__"] = obj.__module__ - - state = obj.__dict__ - if type(state) is not dict: # noqa: E721 - state_dict["__dict__"] = state - state = None - if state_dict: - state = state, state_dict - - dill._dill._save_with_postproc( - pickler, - ( - dill._dill._create_function, - (obj.__code__, globs, obj.__name__, obj.__defaults__, closure), - state, - ), - obj=obj, - postproc_list=postproc_list, - ) - else: - closure = obj.func_closure - if obj.__doc__ is not None: - postproc_list.append((setattr, (obj, "__doc__", obj.__doc__))) - if "__name__" not in globs or obj.__module__ != globs["__name__"]: - postproc_list.append((setattr, (obj, "__module__", obj.__module__))) - if obj.__dict__: - postproc_list.append((setattr, (obj, "__dict__", obj.__dict__))) - - dill._dill._save_with_postproc( - pickler, - (dill._dill._create_function, (obj.func_code, globs, obj.func_name, obj.func_defaults, closure)), - obj=obj, - postproc_list=postproc_list, - ) - - # Lift closure cell update to earliest function (#458) - if _postproc: - topmost_postproc = next(iter(_postproc.values()), None) - if closure and topmost_postproc: - for cell in closure: - possible_postproc = (setattr, (cell, "cell_contents", obj)) - try: - topmost_postproc.remove(possible_postproc) - except ValueError: - continue - - # Change the value of the cell - pickler.save_reduce(*possible_postproc) - # pop None created by calling preprocessing step off stack - if dill._dill.PY3: - pickler.write(bytes("0", "UTF-8")) - else: - pickler.write("0") - - dill._dill.log.info("# F1") - else: - dill._dill.log.info("F2: %s" % obj) - name = getattr(obj, "__qualname__", getattr(obj, "__name__", None)) - dill._dill.StockPickler.save_global(pickler, obj, name=name) - dill._dill.log.info("# F2") - return - -elif config.DILL_VERSION.release[:3] in [version.parse("0.3.6").release, version.parse("0.3.7").release]: - # From: https://github.com/uqfoundation/dill/blob/dill-0.3.6/dill/_dill.py#L1739 - @pklregister(FunctionType) - def save_function(pickler, obj): - if not dill._dill._locate_function(obj, pickler): - if type(obj.__code__) is not CodeType: - # Some PyPy builtin functions have no module name, and thus are not - # able to be located - module_name = getattr(obj, "__module__", None) - if module_name is None: - module_name = dill._dill.__builtin__.__name__ - module = dill._dill._import_module(module_name, safe=True) - _pypy_builtin = False - try: - found, _ = dill._dill._getattribute(module, obj.__qualname__) - if getattr(found, "__func__", None) is obj: - _pypy_builtin = True - except AttributeError: - pass - - if _pypy_builtin: - dill._dill.logger.trace(pickler, "F3: %s", obj) - pickler.save_reduce(getattr, (found, "__func__"), obj=obj) - dill._dill.logger.trace(pickler, "# F3") - return - - dill._dill.logger.trace(pickler, "F1: %s", obj) - _recurse = getattr(pickler, "_recurse", None) - _postproc = getattr(pickler, "_postproc", None) - _main_modified = getattr(pickler, "_main_modified", None) - _original_main = getattr(pickler, "_original_main", dill._dill.__builtin__) # 'None' - postproc_list = [] - if _recurse: - # recurse to get all globals referred to by obj - from dill.detect import globalvars - - globs_copy = globalvars(obj, recurse=True, builtin=True) - - # Add the name of the module to the globs dictionary to prevent - # the duplication of the dictionary. Pickle the unpopulated - # globals dictionary and set the remaining items after the function - # is created to correctly handle recursion. - globs = {"__name__": obj.__module__} - else: - globs_copy = obj.__globals__ - - # If the globals is the __dict__ from the module being saved as a - # session, substitute it by the dictionary being actually saved. - if _main_modified and globs_copy is _original_main.__dict__: - globs_copy = getattr(pickler, "_main", _original_main).__dict__ - globs = globs_copy - # If the globals is a module __dict__, do not save it in the pickle. - elif ( - globs_copy is not None - and obj.__module__ is not None - and getattr(dill._dill._import_module(obj.__module__, True), "__dict__", None) is globs_copy - ): - globs = globs_copy - else: - globs = {"__name__": obj.__module__} - - ######################################################################################################## - # Modification here for huggingface/datasets - # - globs is a dictionary with keys = var names (str) and values = python objects - # - globs_copy is a dictionary with keys = var names (str) and values = ids of the python objects - # However the dictionary is not always loaded in the same order, - # therefore we have to sort the keys to make deterministic. - # This is important to make `dump` deterministic. - # Only these line are different from the original implementation: - # START - globs_is_globs_copy = globs is globs_copy - globs = dict(sorted(globs.items())) - if globs_is_globs_copy: - globs_copy = globs - elif globs_copy is not None: - globs_copy = dict(sorted(globs_copy.items())) - # END - ######################################################################################################## - - if globs_copy is not None and globs is not globs_copy: - # In the case that the globals are copied, we need to ensure that - # the globals dictionary is updated when all objects in the - # dictionary are already created. - glob_ids = {id(g) for g in globs_copy.values()} - for stack_element in _postproc: - if stack_element in glob_ids: - _postproc[stack_element].append((dill._dill._setitems, (globs, globs_copy))) - break - else: - postproc_list.append((dill._dill._setitems, (globs, globs_copy))) - - closure = obj.__closure__ - state_dict = {} - for fattrname in ("__doc__", "__kwdefaults__", "__annotations__"): - fattr = getattr(obj, fattrname, None) - if fattr is not None: - state_dict[fattrname] = fattr - if obj.__qualname__ != obj.__name__: - state_dict["__qualname__"] = obj.__qualname__ - if "__name__" not in globs or obj.__module__ != globs["__name__"]: - state_dict["__module__"] = obj.__module__ - - state = obj.__dict__ - if type(state) is not dict: # noqa: E721 - state_dict["__dict__"] = state - state = None - if state_dict: - state = state, state_dict - - dill._dill._save_with_postproc( - pickler, - (dill._dill._create_function, (obj.__code__, globs, obj.__name__, obj.__defaults__, closure), state), - obj=obj, - postproc_list=postproc_list, - ) - - # Lift closure cell update to earliest function (#458) - if _postproc: - topmost_postproc = next(iter(_postproc.values()), None) - if closure and topmost_postproc: - for cell in closure: - possible_postproc = (setattr, (cell, "cell_contents", obj)) - try: - topmost_postproc.remove(possible_postproc) - except ValueError: - continue - - # Change the value of the cell - pickler.save_reduce(*possible_postproc) - # pop None created by calling preprocessing step off stack - pickler.write(bytes("0", "UTF-8")) - - dill._dill.logger.trace(pickler, "# F1") - else: - dill._dill.logger.trace(pickler, "F2: %s", obj) - name = getattr(obj, "__qualname__", getattr(obj, "__name__", None)) - dill._dill.StockPickler.save_global(pickler, obj, name=name) - dill._dill.logger.trace(pickler, "# F2") - return - - def copyfunc(func): result = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__) result.__kwdefaults__ = func.__kwdefaults__ diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py index f4d5d65744e..47a400fa0c3 100644 --- a/tests/test_fingerprint.py +++ b/tests/test_fingerprint.py @@ -3,7 +3,6 @@ import pickle import subprocess from functools import partial -from hashlib import md5 from pathlib import Path from tempfile import gettempdir from textwrap import dedent @@ -16,10 +15,12 @@ from multiprocess import Pool import datasets +from datasets import config from datasets.fingerprint import Hasher, fingerprint_transform from datasets.table import InMemoryTable from .utils import ( + require_not_windows, require_regex, require_spacy, require_spacy_model, @@ -59,7 +60,25 @@ def __getstate__(self): raise pickle.PicklingError() -class TokenizersDumpTest(TestCase): +if config.TORCH_AVAILABLE: + import torch + import torch.nn as nn + import torch.nn.functional as F + + class TorchModule(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) +else: + TorchModule = None + + +class TokenizersHashTest(TestCase): @require_transformers @pytest.mark.integration def test_hash_tokenizer(self): @@ -70,17 +89,17 @@ def encode(x): # TODO: add hash consistency tests across sessions tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - hash1 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest() - hash1_lambda = md5(datasets.utils.py_utils.dumps(lambda x: tokenizer(x))).hexdigest() - hash1_encode = md5(datasets.utils.py_utils.dumps(encode)).hexdigest() + hash1 = Hasher.hash(tokenizer) + hash1_lambda = Hasher.hash(lambda x: tokenizer(x)) + hash1_encode = Hasher.hash(encode) tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") - hash2 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest() - hash2_lambda = md5(datasets.utils.py_utils.dumps(lambda x: tokenizer(x))).hexdigest() - hash2_encode = md5(datasets.utils.py_utils.dumps(encode)).hexdigest() + hash2 = Hasher.hash(tokenizer) + hash2_lambda = Hasher.hash(lambda x: tokenizer(x)) + hash2_encode = Hasher.hash(encode) tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - hash3 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest() - hash3_lambda = md5(datasets.utils.py_utils.dumps(lambda x: tokenizer(x))).hexdigest() - hash3_encode = md5(datasets.utils.py_utils.dumps(encode)).hexdigest() + hash3 = Hasher.hash(tokenizer) + hash3_lambda = Hasher.hash(lambda x: tokenizer(x)) + hash3_encode = Hasher.hash(encode) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) self.assertEqual(hash1_lambda, hash3_lambda) @@ -94,9 +113,9 @@ def test_hash_tokenizer_with_cache(self): from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("gpt2") - hash1 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest() + hash1 = Hasher.hash(tokenizer) tokenizer("Hello world !") # call once to change the tokenizer's cache - hash2 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest() + hash2 = Hasher.hash(tokenizer) self.assertEqual(hash1, hash2) @require_regex @@ -104,56 +123,56 @@ def test_hash_regex(self): import regex pat = regex.Regex("foo") - hash1 = md5(datasets.utils.py_utils.dumps(pat)).hexdigest() + hash1 = Hasher.hash(pat) pat = regex.Regex("bar") - hash2 = md5(datasets.utils.py_utils.dumps(pat)).hexdigest() + hash2 = Hasher.hash(pat) pat = regex.Regex("foo") - hash3 = md5(datasets.utils.py_utils.dumps(pat)).hexdigest() + hash3 = Hasher.hash(pat) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) -class RecurseDumpTest(TestCase): - def test_recurse_dump_for_function(self): +class RecurseHashTest(TestCase): + def test_recurse_hash_for_function(self): def func(): return foo foo = [0] - hash1 = md5(datasets.utils.py_utils.dumps(func)).hexdigest() + hash1 = Hasher.hash(func) foo = [1] - hash2 = md5(datasets.utils.py_utils.dumps(func)).hexdigest() + hash2 = Hasher.hash(func) foo = [0] - hash3 = md5(datasets.utils.py_utils.dumps(func)).hexdigest() + hash3 = Hasher.hash(func) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) - def test_dump_ignores_line_definition_of_function(self): + def test_hash_ignores_line_definition_of_function(self): def func(): pass - hash1 = md5(datasets.utils.py_utils.dumps(func)).hexdigest() + hash1 = Hasher.hash(func) def func(): pass - hash2 = md5(datasets.utils.py_utils.dumps(func)).hexdigest() + hash2 = Hasher.hash(func) self.assertEqual(hash1, hash2) - def test_recurse_dump_for_class(self): - hash1 = md5(datasets.utils.py_utils.dumps(Foo([0]))).hexdigest() - hash2 = md5(datasets.utils.py_utils.dumps(Foo([1]))).hexdigest() - hash3 = md5(datasets.utils.py_utils.dumps(Foo([0]))).hexdigest() + def test_recurse_hash_for_class(self): + hash1 = Hasher.hash(Foo([0])) + hash2 = Hasher.hash(Foo([1])) + hash3 = Hasher.hash(Foo([0])) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) - def test_recurse_dump_for_method(self): - hash1 = md5(datasets.utils.py_utils.dumps(Foo([0]).__call__)).hexdigest() - hash2 = md5(datasets.utils.py_utils.dumps(Foo([1]).__call__)).hexdigest() - hash3 = md5(datasets.utils.py_utils.dumps(Foo([0]).__call__)).hexdigest() + def test_recurse_hash_for_method(self): + hash1 = Hasher.hash(Foo([0]).__call__) + hash2 = Hasher.hash(Foo([1]).__call__) + hash3 = Hasher.hash(Foo([0]).__call__) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) - def test_dump_ipython_function(self): + def test_hash_ipython_function(self): def create_ipython_func(co_filename, returned_obj): def func(): return returned_obj @@ -164,24 +183,24 @@ def func(): return FunctionType(code, func.__globals__, func.__name__, func.__defaults__, func.__closure__) co_filename, returned_obj = "", [0] - hash1 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + hash1 = Hasher.hash(create_ipython_func(co_filename, returned_obj)) co_filename, returned_obj = "", [1] - hash2 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + hash2 = Hasher.hash(create_ipython_func(co_filename, returned_obj)) co_filename, returned_obj = "", [0] - hash3 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + hash3 = Hasher.hash(create_ipython_func(co_filename, returned_obj)) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) co_filename, returned_obj = os.path.join(gettempdir(), "ipykernel_12345", "321456789.py"), [0] - hash4 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + hash4 = Hasher.hash(create_ipython_func(co_filename, returned_obj)) co_filename, returned_obj = os.path.join(gettempdir(), "ipykernel_12345", "321456789.py"), [1] - hash5 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + hash5 = Hasher.hash(create_ipython_func(co_filename, returned_obj)) co_filename, returned_obj = os.path.join(gettempdir(), "ipykernel_12345", "654123987.py"), [0] - hash6 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest() + hash6 = Hasher.hash(create_ipython_func(co_filename, returned_obj)) self.assertEqual(hash4, hash6) self.assertNotEqual(hash4, hash5) - def test_recurse_dump_for_function_with_shuffled_globals(self): + def test_recurse_hash_for_function_with_shuffled_globals(self): foo, bar = [0], [1] def func(): @@ -196,10 +215,10 @@ def globalvars_mock2_side_effect(func, *args, **kwargs): return {"bar": bar, "foo": foo} with patch("dill.detect.globalvars", side_effect=globalvars_mock1_side_effect) as globalvars_mock1: - hash1 = md5(datasets.utils.py_utils.dumps(func)).hexdigest() + hash1 = Hasher.hash(func) self.assertGreater(globalvars_mock1.call_count, 0) with patch("dill.detect.globalvars", side_effect=globalvars_mock2_side_effect) as globalvars_mock2: - hash2 = md5(datasets.utils.py_utils.dumps(func)).hexdigest() + hash2 = Hasher.hash(func) self.assertGreater(globalvars_mock2.call_count, 0) self.assertEqual(hash1, hash2) @@ -264,11 +283,11 @@ def test_set_stable(self): def test_set_doesnt_depend_on_order(self): set_ = set("abc") - hash1 = md5(datasets.utils.py_utils.dumps(set_)).hexdigest() + hash1 = Hasher.hash(set_) set_ = set("def") - hash2 = md5(datasets.utils.py_utils.dumps(set_)).hexdigest() + hash2 = Hasher.hash(set_) set_ = set("cba") - hash3 = md5(datasets.utils.py_utils.dumps(set_)).hexdigest() + hash3 = Hasher.hash(set_) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) @@ -277,11 +296,11 @@ def test_hash_tiktoken_encoding(self): import tiktoken enc = tiktoken.get_encoding("gpt2") - hash1 = md5(datasets.utils.py_utils.dumps(enc)).hexdigest() + hash1 = Hasher.hash(enc) enc = tiktoken.get_encoding("r50k_base") - hash2 = md5(datasets.utils.py_utils.dumps(enc)).hexdigest() + hash2 = Hasher.hash(enc) enc = tiktoken.get_encoding("gpt2") - hash3 = md5(datasets.utils.py_utils.dumps(enc)).hexdigest() + hash3 = Hasher.hash(enc) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) @@ -290,11 +309,11 @@ def test_hash_torch_tensor(self): import torch t = torch.tensor([1.0]) - hash1 = md5(datasets.utils.py_utils.dumps(t)).hexdigest() + hash1 = Hasher.hash(t) t = torch.tensor([2.0]) - hash2 = md5(datasets.utils.py_utils.dumps(t)).hexdigest() + hash2 = Hasher.hash(t) t = torch.tensor([1.0]) - hash3 = md5(datasets.utils.py_utils.dumps(t)).hexdigest() + hash3 = Hasher.hash(t) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) @@ -306,14 +325,43 @@ def test_hash_spacy_model(self): import spacy nlp = spacy.load("en_core_web_sm") - hash1 = md5(datasets.utils.py_utils.dumps(nlp)).hexdigest() + hash1 = Hasher.hash(nlp) nlp = spacy.load("fr_core_news_sm") - hash2 = md5(datasets.utils.py_utils.dumps(nlp)).hexdigest() + hash2 = Hasher.hash(nlp) nlp = spacy.load("en_core_web_sm") - hash3 = md5(datasets.utils.py_utils.dumps(nlp)).hexdigest() + hash3 = Hasher.hash(nlp) self.assertEqual(hash1, hash3) self.assertNotEqual(hash1, hash2) + @require_not_windows + @require_torch + def test_hash_torch_compiled_function(self): + import torch + + def f(x): + return torch.sin(x) + torch.cos(x) + + hash1 = Hasher.hash(f) + f = torch.compile(f) + hash2 = Hasher.hash(f) + self.assertEqual(hash1, hash2) + + @require_not_windows + @require_torch + def test_hash_torch_compiled_module(self): + m = TorchModule() + next(iter(m.parameters())).data.fill_(1.0) + hash1 = Hasher.hash(m) + m = torch.compile(m) + hash2 = Hasher.hash(m) + m = TorchModule() + next(iter(m.parameters())).data.fill_(2.0) + m = torch.compile(m) + hash3 = Hasher.hash(m) + self.assertEqual(hash1, hash2) + self.assertNotEqual(hash1, hash3) + self.assertNotEqual(hash2, hash3) + @pytest.mark.integration def test_move_script_doesnt_change_hash(tmp_path: Path):