diff --git a/src/datasets/utils/_dill.py b/src/datasets/utils/_dill.py index 2dedf7f1fbc..2a414459266 100644 --- a/src/datasets/utils/_dill.py +++ b/src/datasets/utils/_dill.py @@ -165,7 +165,7 @@ def _save_torchTensor(pickler, obj): def create_torchTensor(np_array, dtype=None): tensor = torch.from_numpy(np_array) if dtype: - tensor = tensor.type(torch.bfloat16) + tensor = tensor.type(dtype) return tensor log(pickler, f"To: {obj}")