Skip to content

Commit c3a8a87

Browse files
Fix loading Hub datasets with CSV metadata file (#6316)
* Test load dataset with CSV metadata from Hub
* Pass metadata_ext to FolderBasedBuilder._read_metadata
1 parent 2b19f6b commit c3a8a87

File tree

2 files changed

+88
-14
lines changed

2 files changed

+88
-14
lines changed

src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,15 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):
164164

165165
# Check that all metadata files share the same format
166166
metadata_ext = {
167-
os.path.splitext(downloaded_metadata_file)[1][1:]
168-
for _, downloaded_metadata_file in itertools.chain.from_iterable(metadata_files.values())
167+
os.path.splitext(original_metadata_file)[-1]
168+
for original_metadata_file, _ in itertools.chain.from_iterable(metadata_files.values())
169169
}
170170
if len(metadata_ext) > 1:
171171
raise ValueError(f"Found metadata files with different extensions: {list(metadata_ext)}")
172172
metadata_ext = metadata_ext.pop()
173173

174174
for _, downloaded_metadata_file in itertools.chain.from_iterable(metadata_files.values()):
175-
pa_metadata_table = self._read_metadata(downloaded_metadata_file)
175+
pa_metadata_table = self._read_metadata(downloaded_metadata_file, metadata_ext=metadata_ext)
176176
features_per_metadata_file.append(
177177
(downloaded_metadata_file, datasets.Features.from_arrow_schema(pa_metadata_table.schema))
178178
)
@@ -236,9 +236,8 @@ def _split_files_and_archives(self, data_files):
236236
archives.append(data_file)
237237
return files, archives
238238

239-
def _read_metadata(self, metadata_file):
240-
metadata_file_ext = os.path.splitext(metadata_file)[1][1:]
241-
if metadata_file_ext == "csv":
239+
def _read_metadata(self, metadata_file, metadata_ext: str = ""):
240+
if metadata_ext == ".csv":
242241
# Use `pd.read_csv` (although slower) instead of `pyarrow.csv.read_csv` for reading CSV files for consistency with the CSV packaged module
243242
return pa.Table.from_pandas(pd.read_csv(metadata_file))
244243
else:
@@ -255,10 +254,10 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
255254
metadata_dict = None
256255
downloaded_metadata_file = None
257256

257+
metadata_ext = ""
258258
if split_metadata_files:
259259
metadata_ext = {
260-
os.path.splitext(downloaded_metadata_file)[1][1:]
261-
for _, downloaded_metadata_file in split_metadata_files
260+
os.path.splitext(original_metadata_file)[-1] for original_metadata_file, _ in split_metadata_files
262261
}
263262
metadata_ext = metadata_ext.pop()
264263

@@ -290,7 +289,9 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
290289
_, metadata_file, downloaded_metadata_file = min(
291290
metadata_file_candidates, key=lambda x: count_path_segments(x[0])
292291
)
293-
pa_metadata_table = self._read_metadata(downloaded_metadata_file)
292+
pa_metadata_table = self._read_metadata(
293+
downloaded_metadata_file, metadata_ext=metadata_ext
294+
)
294295
pa_file_name_array = pa_metadata_table["file_name"]
295296
pa_metadata_table = pa_metadata_table.drop(["file_name"])
296297
metadata_dir = os.path.dirname(metadata_file)
@@ -302,7 +303,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
302303
}
303304
else:
304305
raise ValueError(
305-
f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
306+
f"One or several metadata{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
306307
)
307308
if metadata_dir is not None and downloaded_metadata_file is not None:
308309
file_relpath = os.path.relpath(original_file, metadata_dir)
@@ -314,7 +315,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
314315
sample_metadata = metadata_dict[file_relpath]
315316
else:
316317
raise ValueError(
317-
f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
318+
f"One or several metadata{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
318319
)
319320
else:
320321
sample_metadata = {}
@@ -356,7 +357,9 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
356357
_, metadata_file, downloaded_metadata_file = min(
357358
metadata_file_candidates, key=lambda x: count_path_segments(x[0])
358359
)
359-
pa_metadata_table = self._read_metadata(downloaded_metadata_file)
360+
pa_metadata_table = self._read_metadata(
361+
downloaded_metadata_file, metadata_ext=metadata_ext
362+
)
360363
pa_file_name_array = pa_metadata_table["file_name"]
361364
pa_metadata_table = pa_metadata_table.drop(["file_name"])
362365
metadata_dir = os.path.dirname(downloaded_metadata_file)
@@ -368,7 +371,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
368371
}
369372
else:
370373
raise ValueError(
371-
f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
374+
f"One or several metadata{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
372375
)
373376
if metadata_dir is not None and downloaded_metadata_file is not None:
374377
downloaded_dir_file_relpath = os.path.relpath(downloaded_dir_file, metadata_dir)
@@ -380,7 +383,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
380383
sample_metadata = metadata_dict[downloaded_dir_file_relpath]
381384
else:
382385
raise ValueError(
383-
f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
386+
f"One or several metadata{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
384387
)
385388
else:
386389
sample_metadata = {}

tests/test_upstream_hub.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import fnmatch
22
import gc
33
import os
4+
import shutil
45
import tempfile
6+
import textwrap
57
import time
68
import unittest
79
from io import BytesIO
@@ -17,13 +19,18 @@
1719
ClassLabel,
1820
Dataset,
1921
DatasetDict,
22+
DownloadManager,
2023
Features,
2124
Image,
2225
Value,
2326
load_dataset,
2427
load_dataset_builder,
2528
)
2629
from datasets.config import METADATA_CONFIGS_FIELD
30+
from datasets.packaged_modules.folder_based_builder.folder_based_builder import (
31+
FolderBasedBuilder,
32+
FolderBasedBuilderConfig,
33+
)
2734
from datasets.utils.file_utils import cached_path
2835
from datasets.utils.hub import hf_hub_url
2936
from tests.fixtures.hub import CI_HUB_ENDPOINT, CI_HUB_USER, CI_HUB_USER_TOKEN
@@ -813,3 +820,67 @@ def test_push_dataset_dict_to_hub_with_config_no_metadata_configs(self, temporar
813820
ds_another_config_builder.config.data_files["random"][0],
814821
"*/another_config/random-00000-of-00001.parquet",
815822
)
823+
824+
825+
class DummyFolderBasedBuilder(FolderBasedBuilder):
826+
BASE_FEATURE = dict
827+
BASE_COLUMN_NAME = "base"
828+
BUILDER_CONFIG_CLASS = FolderBasedBuilderConfig
829+
EXTENSIONS = [".txt"]
830+
# CLASSIFICATION_TASK = TextClassification(text_column="base", label_column="label")
831+
832+
833+
@pytest.fixture(params=[".jsonl", ".csv"])
834+
def text_file_with_metadata(request, tmp_path, text_file):
835+
metadata_filename_extension = request.param
836+
data_dir = tmp_path / "data_dir"
837+
data_dir.mkdir()
838+
text_file_path = data_dir / "file.txt"
839+
shutil.copyfile(text_file, text_file_path)
840+
metadata_file_path = data_dir / f"metadata{metadata_filename_extension}"
841+
metadata = textwrap.dedent(
842+
"""\
843+
{"file_name": "file.txt", "additional_feature": "Dummy file"}
844+
"""
845+
if metadata_filename_extension == ".jsonl"
846+
else """\
847+
file_name,additional_feature
848+
file.txt,Dummy file
849+
"""
850+
)
851+
with open(metadata_file_path, "w", encoding="utf-8") as f:
852+
f.write(metadata)
853+
return text_file_path, metadata_file_path
854+
855+
856+
@for_all_test_methods(xfail_if_500_502_http_error)
857+
@pytest.mark.usefixtures("ci_hub_config", "ci_hfh_hf_hub_url")
858+
class TestLoadFromHub:
859+
_api = HfApi(endpoint=CI_HUB_ENDPOINT)
860+
_token = CI_HUB_USER_TOKEN
861+
862+
def test_load_dataset_with_metadata_file(self, temporary_repo, text_file_with_metadata, tmp_path):
863+
text_file_path, metadata_file_path = text_file_with_metadata
864+
data_dir_path = text_file_path.parent
865+
cache_dir_path = tmp_path / ".cache"
866+
cache_dir_path.mkdir()
867+
with temporary_repo() as repo_id:
868+
self._api.create_repo(repo_id, token=self._token, repo_type="dataset")
869+
self._api.upload_folder(
870+
folder_path=str(data_dir_path),
871+
repo_id=repo_id,
872+
repo_type="dataset",
873+
token=self._token,
874+
)
875+
data_files = [
876+
f"hf://datasets/{repo_id}/{text_file_path.name}",
877+
f"hf://datasets/{repo_id}/{metadata_file_path.name}",
878+
]
879+
builder = DummyFolderBasedBuilder(
880+
dataset_name=repo_id.split("/")[-1], data_files=data_files, cache_dir=str(cache_dir_path)
881+
)
882+
download_manager = DownloadManager()
883+
gen_kwargs = builder._split_generators(download_manager)[0].gen_kwargs
884+
generator = builder._generate_examples(**gen_kwargs)
885+
result = [example for _, example in generator]
886+
assert len(result) == 1

0 commit comments

Comments (0)