
[Caching] Deterministic hashing of torch tensors #5170

@lhoestq


Currently this fails:

import torch
from datasets.fingerprint import Hasher

t = torch.tensor([1.])

def func(x):
    return t + x

hash1 = Hasher.hash(func)
t = torch.tensor([1.])  # re-create an identical tensor
hash2 = Hasher.hash(func)
assert hash1 == hash2  # AssertionError: the hash changed

The hash changes because hashing func pickles the captured tensor t, and pickling a torch tensor is not deterministic: two tensors with identical content can serialize to different bytes.

Also, as noticed in https://discuss.huggingface.co/t/dataset-cant-cache-models-outputs/24945, using a model in a map function doesn't work well with caching: the bert-base-uncased model gets a different hash every time you reload it. Since a model's parameters are torch tensors, supporting torch tensors may also help in this case.
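
For illustration, the symptom from that thread looks roughly like this (a minimal sketch, assuming transformers is installed; the exact hash values don't matter, only that they differ):

from datasets.fingerprint import Hasher
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
hash1 = Hasher.hash(model)

# reload the exact same pretrained weights
model = AutoModel.from_pretrained("bert-base-uncased")
hash2 = Hasher.hash(model)

# hash1 != hash2, so a map() function that closes over the model gets a
# new fingerprint on every run and the cache is never reused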

This can be fixed by registering a custom pickling function for torch tensors, as we did for other objects such as CodeType, FunctionType and Regex in py_utils.py.
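
A minimal sketch of what such a registration could look like, assuming the dill-based Pickler and the pklregister decorator in datasets/utils/py_utils.py work for torch.Tensor the same way they do for CodeType (the _rebuild_tensor helper is hypothetical, and device, requires_grad, and numpy-incompatible dtypes are ignored for brevity):

import torch
from datasets.utils.py_utils import pklregister

def _rebuild_tensor(arr):
    # hypothetical helper: rebuild the tensor from its content
    return torch.from_numpy(arr)

@pklregister(torch.Tensor)
def _save_tensor(pickler, obj):
    # reduce the tensor to its content (dtype, shape, values) instead of
    # letting dill serialize the underlying storage, whose bytes are not
    # deterministic across otherwise-identical tensors
    pickler.save_reduce(_rebuild_tensor, (obj.detach().cpu().numpy(),), obj=obj)

Note that the dispatch table is keyed by exact type, so subclasses like torch.nn.Parameter would need their own entry. With a reducer along these lines registered, the snippet above produces the same hash before and after re-creating t, since the pickled bytes now depend only on the tensor's content.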
