Skip to content

Commit 6358dbd

Browse files
committed
better error message when using the wrong load_from_disk
1 parent 375056d commit 6358dbd

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

src/datasets/arrow_dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,12 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
674674
tmp_dir = tempfile.TemporaryDirectory()
675675
dataset_path = Path(tmp_dir.name, src_dataset_path)
676676
fs.download(src_dataset_path, dataset_path.as_posix(), recursive=True)
677+
dataset_dict_json_path = Path(dataset_path, config.DATASETDICT_JSON_FILENAME).as_posix()
678+
dataset_info_path = Path(dataset_path, config.DATASET_INFO_FILENAME).as_posix()
679+
if not fs.isfile(dataset_info_path) and fs.isfile(dataset_dict_json_path):
680+
raise FileNotFoundError(
681+
f"No such file or directory: '{dataset_info_path}'. Looks like you tried to load a DatasetDict object, not a Dataset. Please use DatasetDict.load_from_disk instead."
682+
)
677683

678684
with open(
679685
Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "r", encoding="utf-8"

src/datasets/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
DATASETDICT_INFOS_FILENAME = "dataset_infos.json"
158158
LICENSE_FILENAME = "LICENSE"
159159
METRIC_INFO_FILENAME = "metric_info.json"
160+
DATASETDICT_JSON_FILENAME = "dataset_dict.json"
160161

161162
MODULE_NAME_FOR_DYNAMIC_MODULES = "datasets_modules"
162163

src/datasets/dataset_dict.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from datasets.utils.doc_utils import is_documented_by
1313

14+
from . import config
1415
from .arrow_dataset import Dataset
1516
from .features import Features
1617
from .filesystems import extract_path_from_uri, is_remote_filesystem
@@ -673,7 +674,7 @@ def save_to_disk(self, dataset_dict_path: str, fs=None):
673674

674675
json.dump(
675676
{"splits": list(self)},
676-
fs.open(Path(dest_dataset_dict_path, "dataset_dict.json").as_posix(), "w", encoding="utf-8"),
677+
fs.open(Path(dest_dataset_dict_path, config.DATASETDICT_JSON_FILENAME).as_posix(), "w", encoding="utf-8"),
677678
)
678679
for k, dataset in self.items():
679680
dataset.save_to_disk(Path(dest_dataset_dict_path, k).as_posix(), fs)
@@ -706,8 +707,14 @@ def load_from_disk(dataset_dict_path: str, fs=None, keep_in_memory: Optional[boo
706707
else:
707708
fs = fsspec.filesystem("file")
708709
dest_dataset_dict_path = dataset_dict_path
710+
dataset_dict_json_path = Path(dest_dataset_dict_path, config.DATASETDICT_JSON_FILENAME).as_posix()
711+
dataset_info_path = Path(dest_dataset_dict_path, config.DATASET_INFO_FILENAME).as_posix()
712+
if fs.isfile(dataset_info_path) and not fs.isfile(dataset_dict_json_path):
713+
raise FileNotFoundError(
714+
f"No such file or directory: '{dataset_dict_json_path}'. Looks like you tried to load a Dataset object, not a DatasetDict. Please use Dataset.load_from_disk instead."
715+
)
709716
for k in json.load(
710-
fs.open(Path(dest_dataset_dict_path, "dataset_dict.json").as_posix(), "r", encoding="utf-8")
717+
# Open the manifest that save_to_disk writes under config.DATASETDICT_JSON_FILENAME; that file carries the "splits" list read below.
fs.open(Path(dest_dataset_dict_path, config.DATASETDICT_JSON_FILENAME).as_posix(), "r", encoding="utf-8")
711718
)["splits"]:
712719
dataset_dict_split_path = (
713720
dataset_dict_path.split("://")[0] + "://" + Path(dest_dataset_dict_path, k).as_posix()

src/datasets/load.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,9 +797,9 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
797797

798798
if not fs.exists(dest_dataset_path):
799799
raise FileNotFoundError("Directory {} not found".format(dataset_path))
800-
if fs.isfile(Path(dest_dataset_path, "dataset_info.json").as_posix()):
800+
if fs.isfile(Path(dest_dataset_path, config.DATASET_INFO_FILENAME).as_posix()):
801801
return Dataset.load_from_disk(dataset_path, fs, keep_in_memory=keep_in_memory)
802-
elif fs.isfile(Path(dest_dataset_path, "dataset_dict.json").as_posix()):
802+
elif fs.isfile(Path(dest_dataset_path, config.DATASETDICT_JSON_FILENAME).as_posix()):
803803
return DatasetDict.load_from_disk(dataset_path, fs, keep_in_memory=keep_in_memory)
804804
else:
805805
raise FileNotFoundError(

0 commit comments

Comments
 (0)