Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,15 @@ def analyze(files_or_archives, downloaded_files_or_dirs, split):

# Check that all metadata files share the same format
metadata_ext = {
os.path.splitext(downloaded_metadata_file)[1][1:]
for _, downloaded_metadata_file in itertools.chain.from_iterable(metadata_files.values())
os.path.splitext(original_metadata_file)[-1]
for original_metadata_file, _ in itertools.chain.from_iterable(metadata_files.values())
}
if len(metadata_ext) > 1:
raise ValueError(f"Found metadata files with different extensions: {list(metadata_ext)}")
metadata_ext = metadata_ext.pop()

for _, downloaded_metadata_file in itertools.chain.from_iterable(metadata_files.values()):
pa_metadata_table = self._read_metadata(downloaded_metadata_file)
pa_metadata_table = self._read_metadata(downloaded_metadata_file, metadata_ext=metadata_ext)
features_per_metadata_file.append(
(downloaded_metadata_file, datasets.Features.from_arrow_schema(pa_metadata_table.schema))
)
Expand Down Expand Up @@ -236,9 +236,8 @@ def _split_files_and_archives(self, data_files):
archives.append(data_file)
return files, archives

def _read_metadata(self, metadata_file):
metadata_file_ext = os.path.splitext(metadata_file)[1][1:]
if metadata_file_ext == "csv":
def _read_metadata(self, metadata_file, metadata_ext: str = ""):
if metadata_ext == ".csv":
# Use `pd.read_csv` (although slower) instead of `pyarrow.csv.read_csv` for reading CSV files for consistency with the CSV packaged module
return pa.Table.from_pandas(pd.read_csv(metadata_file))
else:
Expand All @@ -255,10 +254,10 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
metadata_dict = None
downloaded_metadata_file = None

metadata_ext = ""
if split_metadata_files:
metadata_ext = {
os.path.splitext(downloaded_metadata_file)[1][1:]
for _, downloaded_metadata_file in split_metadata_files
os.path.splitext(original_metadata_file)[-1] for original_metadata_file, _ in split_metadata_files
}
metadata_ext = metadata_ext.pop()

Expand Down Expand Up @@ -290,7 +289,9 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
_, metadata_file, downloaded_metadata_file = min(
metadata_file_candidates, key=lambda x: count_path_segments(x[0])
)
pa_metadata_table = self._read_metadata(downloaded_metadata_file)
pa_metadata_table = self._read_metadata(
downloaded_metadata_file, metadata_ext=metadata_ext
)
pa_file_name_array = pa_metadata_table["file_name"]
pa_metadata_table = pa_metadata_table.drop(["file_name"])
metadata_dir = os.path.dirname(metadata_file)
Expand All @@ -302,7 +303,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
}
else:
raise ValueError(
f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
f"One or several metadata{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
)
if metadata_dir is not None and downloaded_metadata_file is not None:
file_relpath = os.path.relpath(original_file, metadata_dir)
Expand All @@ -314,7 +315,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
sample_metadata = metadata_dict[file_relpath]
else:
raise ValueError(
f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
f"One or several metadata{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_file_or_dir}."
)
else:
sample_metadata = {}
Expand Down Expand Up @@ -356,7 +357,9 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
_, metadata_file, downloaded_metadata_file = min(
metadata_file_candidates, key=lambda x: count_path_segments(x[0])
)
pa_metadata_table = self._read_metadata(downloaded_metadata_file)
pa_metadata_table = self._read_metadata(
downloaded_metadata_file, metadata_ext=metadata_ext
)
pa_file_name_array = pa_metadata_table["file_name"]
pa_metadata_table = pa_metadata_table.drop(["file_name"])
metadata_dir = os.path.dirname(downloaded_metadata_file)
Expand All @@ -368,7 +371,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
}
else:
raise ValueError(
f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
f"One or several metadata{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
)
if metadata_dir is not None and downloaded_metadata_file is not None:
downloaded_dir_file_relpath = os.path.relpath(downloaded_dir_file, metadata_dir)
Expand All @@ -380,7 +383,7 @@ def _generate_examples(self, files, metadata_files, split_name, add_metadata, ad
sample_metadata = metadata_dict[downloaded_dir_file_relpath]
else:
raise ValueError(
f"One or several metadata.{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
f"One or several metadata{metadata_ext} were found, but not in the same directory or in a parent directory of {downloaded_dir_file}."
)
else:
sample_metadata = {}
Expand Down
71 changes: 71 additions & 0 deletions tests/test_upstream_hub.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import fnmatch
import gc
import os
import shutil
import tempfile
import textwrap
import time
import unittest
from io import BytesIO
Expand All @@ -17,13 +19,18 @@
ClassLabel,
Dataset,
DatasetDict,
DownloadManager,
Features,
Image,
Value,
load_dataset,
load_dataset_builder,
)
from datasets.config import METADATA_CONFIGS_FIELD
from datasets.packaged_modules.folder_based_builder.folder_based_builder import (
FolderBasedBuilder,
FolderBasedBuilderConfig,
)
from datasets.utils.file_utils import cached_path
from datasets.utils.hub import hf_hub_url
from tests.fixtures.hub import CI_HUB_ENDPOINT, CI_HUB_USER, CI_HUB_USER_TOKEN
Expand Down Expand Up @@ -813,3 +820,67 @@ def test_push_dataset_dict_to_hub_with_config_no_metadata_configs(self, temporar
ds_another_config_builder.config.data_files["random"][0],
"*/another_config/random-00000-of-00001.parquet",
)


class DummyFolderBasedBuilder(FolderBasedBuilder):
BASE_FEATURE = dict
BASE_COLUMN_NAME = "base"
BUILDER_CONFIG_CLASS = FolderBasedBuilderConfig
EXTENSIONS = [".txt"]
# CLASSIFICATION_TASK = TextClassification(text_column="base", label_column="label")


@pytest.fixture(params=[".jsonl", ".csv"])
def text_file_with_metadata(request, tmp_path, text_file):
metadata_filename_extension = request.param
data_dir = tmp_path / "data_dir"
data_dir.mkdir()
text_file_path = data_dir / "file.txt"
shutil.copyfile(text_file, text_file_path)
metadata_file_path = data_dir / f"metadata{metadata_filename_extension}"
metadata = textwrap.dedent(
"""\
{"file_name": "file.txt", "additional_feature": "Dummy file"}
"""
if metadata_filename_extension == ".jsonl"
else """\
file_name,additional_feature
file.txt,Dummy file
"""
)
with open(metadata_file_path, "w", encoding="utf-8") as f:
f.write(metadata)
return text_file_path, metadata_file_path


@for_all_test_methods(xfail_if_500_502_http_error)
@pytest.mark.usefixtures("ci_hub_config", "ci_hfh_hf_hub_url")
class TestLoadFromHub:
_api = HfApi(endpoint=CI_HUB_ENDPOINT)
_token = CI_HUB_USER_TOKEN

def test_load_dataset_with_metadata_file(self, temporary_repo, text_file_with_metadata, tmp_path):
text_file_path, metadata_file_path = text_file_with_metadata
data_dir_path = text_file_path.parent
cache_dir_path = tmp_path / ".cache"
cache_dir_path.mkdir()
with temporary_repo() as repo_id:
self._api.create_repo(repo_id, token=self._token, repo_type="dataset")
self._api.upload_folder(
folder_path=str(data_dir_path),
repo_id=repo_id,
repo_type="dataset",
token=self._token,
)
data_files = [
f"hf://datasets/{repo_id}/{text_file_path.name}",
f"hf://datasets/{repo_id}/{metadata_file_path.name}",
]
builder = DummyFolderBasedBuilder(
dataset_name=repo_id.split("/")[-1], data_files=data_files, cache_dir=str(cache_dir_path)
)
download_manager = DownloadManager()
gen_kwargs = builder._split_generators(download_manager)[0].gen_kwargs
generator = builder._generate_examples(**gen_kwargs)
result = [example for _, example in generator]
assert len(result) == 1