Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
"tensorflow>=2.3,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'",
"tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'",
"tiktoken",
"torch",
"torch>=2.0.0",
"soundfile>=0.12.1",
"transformers",
"typing-extensions>=4.6.1", # due to conflict between apache-beam and pydantic
Expand Down
37 changes: 5 additions & 32 deletions src/datasets/fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import json
import os
import random
import shutil
Expand All @@ -10,15 +9,12 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pyarrow as pa
import xxhash

from .info import DatasetInfo
from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH
from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table
from .utils._dill import dumps
from .utils.deprecation_utils import deprecated
from .utils.logging import get_logger
from .utils.py_utils import asdict, dumps


if TYPE_CHECKING:
Expand Down Expand Up @@ -199,6 +195,7 @@ def cleanup_func():
#################


@deprecated("Use `copyreg.pickle` to register a custom reducer.")
def hashregister(*types):
def proxy(func):
for t in types:
Expand All @@ -225,15 +222,13 @@ def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str:
return m.hexdigest()

@classmethod
@deprecated("Use `Hasher.hash` instead.")
def hash_default(cls, value: Any) -> str:
return cls.hash_bytes(dumps(value))
return cls.hash(value)

@classmethod
def hash(cls, value: Any) -> str:
if type(value) in cls.dispatch:
return cls.dispatch[type(value)](cls, value)
else:
return cls.hash_default(value)
return cls.hash_bytes(dumps(value))

def update(self, value: Any) -> None:
header_for_update = f"=={type(value)}=="
Expand All @@ -245,28 +240,6 @@ def hexdigest(self) -> str:
return self.m.hexdigest()


# Registering a new hasher can be useful for two possible reasons:
# 1 - to optimize the hashing of large amounts of data (e.g. pa.Table)
# 2 - to take advantage of a custom serialization method (e.g. DatasetInfo)


@hashregister(pa.Table, Table, InMemoryTable, MemoryMappedTable, ConcatenationTable)
def _hash_pa_table(hasher, value):
    """Hash a table column by column, based on each column's string rendering.

    Every column is rendered with ``to_string()``, UTF-8 encoded and hashed with
    ``hasher.hash_bytes``. The per-column results are then combined, in sorted
    column-name order, as ``name-digest`` pairs joined by ``-``, and that combined
    string is hashed once more to produce the final digest.

    NOTE(review): ``to_string()` is called with its default arguments. If
    pyarrow's default rendering elides elements of long arrays, distinct tables
    could produce the same digest -- confirm against the pyarrow documentation.

    Args:
        hasher: Object exposing ``hash_bytes``; it is called with either a single
            ``bytes`` value or an iterable of ``bytes``.
        value: Table-like object exposing ``column_names`` and column lookup by name.

    Returns:
        The digest returned by ``hasher.hash_bytes`` for the combined string.
    """

    def _column_digest(column):
        # A plain array is hashed in one piece; a chunked array feeds one
        # encoded rendering per chunk to the hasher.
        if not isinstance(column, pa.ChunkedArray):
            return hasher.hash_bytes(column.to_string().encode("utf-8"))
        return hasher.hash_bytes(chunk.to_string().encode("utf-8") for chunk in column.chunks)

    # Sorted names make the result independent of the column order in the table.
    pieces = []
    for name in sorted(value.column_names):
        pieces.append(name + "-" + _column_digest(value[name]))
    combined = "-".join(pieces)
    return hasher.hash_bytes(combined.encode("utf-8"))


@hashregister(DatasetInfo)
def _hash_dataset_info(hasher, value):
    """Hash a ``DatasetInfo`` through its key-sorted JSON serialization.

    Args:
        hasher: Object exposing ``hash_bytes``.
        value: The object to hash; it is converted with ``asdict`` before being
            serialized to JSON.

    Returns:
        The digest returned by ``hasher.hash_bytes`` for the UTF-8 encoded JSON text.
    """
    info_as_dict = asdict(value)
    # sort_keys=True keeps the serialized text independent of dict key ordering.
    serialized = json.dumps(info_as_dict, sort_keys=True)
    return hasher.hash_bytes(serialized.encode("utf-8"))


#################
# Fingerprinting
#################
Expand Down
Loading