From d50e81f7884894a32e468f7c96f954c73495bb72 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 25 Jun 2024 18:09:54 +0200 Subject: [PATCH] minor fix for bfloat16 --- src/datasets/utils/_dill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/utils/_dill.py b/src/datasets/utils/_dill.py index 2dedf7f1fbc..2a414459266 100644 --- a/src/datasets/utils/_dill.py +++ b/src/datasets/utils/_dill.py @@ -165,7 +165,7 @@ def _save_torchTensor(pickler, obj): def create_torchTensor(np_array, dtype=None): tensor = torch.from_numpy(np_array) if dtype: - tensor = tensor.type(torch.bfloat16) + tensor = tensor.type(dtype) return tensor log(pickler, f"To: {obj}")