diff --git a/src/datasets/builder.py b/src/datasets/builder.py index a078cb4c2c8..93ce534a928 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -52,6 +52,7 @@ from .download.download_manager import DownloadManager, DownloadMode from .download.mock_download_manager import MockDownloadManager from .download.streaming_download_manager import StreamingDownloadManager, xopen +from .exceptions import DatasetGenerationCastError, DatasetGenerationError, FileFormatError, ManualDownloadError from .features import Features from .filesystems import ( is_remote_filesystem, @@ -64,6 +65,7 @@ from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase from .splits import Split, SplitDict, SplitGenerator, SplitInfo from .streaming import extend_dataset_builder_for_streaming +from .table import CastError from .utils import logging from .utils import tqdm as hf_tqdm from .utils._filelock import FileLock @@ -80,6 +82,7 @@ temporary_assignment, ) from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs +from .utils.track import tracked_list logger = logging.get_logger(__name__) @@ -89,22 +92,6 @@ class InvalidConfigName(ValueError): pass -class DatasetBuildError(Exception): - pass - - -class ManualDownloadError(DatasetBuildError): - pass - - -class DatasetGenerationError(DatasetBuildError): - pass - - -class FileFormatError(DatasetBuildError): - pass - - @dataclass class BuilderConfig: """Base class for `DatasetBuilder` data configuration. 
@@ -1895,6 +1882,7 @@ def _rename_shard(shard_id_and_job: Tuple[int]): def _prepare_split_single( self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]: + gen_kwargs = {k: tracked_list(v) if isinstance(v, list) else v for k, v in gen_kwargs.items()} generator = self._generate_tables(**gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter embed_local_files = file_format == "parquet" @@ -1928,7 +1916,15 @@ def _prepare_split_single( storage_options=self._fs.storage_options, embed_local_files=embed_local_files, ) - writer.write_table(table) + try: + writer.write_table(table) + except CastError as cast_error: + raise DatasetGenerationCastError.from_cast_error( + cast_error=cast_error, + builder_name=self.info.builder_name, + gen_kwargs=gen_kwargs, + token=self.token, + ) num_examples_progress_update += len(table) if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL: _time = time.time() @@ -1946,6 +1942,8 @@ def _prepare_split_single( # Ignore the writer's error for no examples written to the file if this error was caused by the error in _generate_examples before the first example was yielded if isinstance(e, SchemaInferenceError) and e.__context__ is not None: e = e.__context__ + if isinstance(e, DatasetGenerationError): + raise raise DatasetGenerationError("An error occurred while generating the dataset") from e yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths) diff --git a/src/datasets/download/download_manager.py b/src/datasets/download/download_manager.py index d4fe3c9fafa..9dcda45262d 100644 --- a/src/datasets/download/download_manager.py +++ b/src/datasets/download/download_manager.py @@ -25,7 +25,7 @@ from datetime import datetime from functools import partial from itertools import chain -from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union +from typing 
import Callable, Dict, Generator, List, Optional, Tuple, Union from .. import config from ..utils import tqdm as hf_tqdm @@ -34,6 +34,7 @@ from ..utils.info_utils import get_size_checksum_dict from ..utils.logging import get_logger from ..utils.py_utils import NestedDataStructure, map_nested, size_str +from ..utils.track import TrackedIterable, tracked_str from .download_config import DownloadConfig @@ -147,16 +148,20 @@ def _get_extraction_protocol(path: str) -> Optional[str]: return _get_extraction_protocol_with_magic_number(f) -class _IterableFromGenerator(Iterable): +class _IterableFromGenerator(TrackedIterable): """Utility class to create an iterable from a generator function, in order to reset the generator when needed.""" def __init__(self, generator: Callable, *args, **kwargs): + super().__init__() self.generator = generator self.args = args self.kwargs = kwargs def __iter__(self): - yield from self.generator(*self.args, **self.kwargs) + for x in self.generator(*self.args, **self.kwargs): + self.last_item = x + yield x + self.last_item = None class ArchiveIterable(_IterableFromGenerator): @@ -443,7 +448,10 @@ def _download(self, url_or_filename: str, download_config: DownloadConfig) -> st if is_relative_path(url_or_filename): # append the relative path to the base_path url_or_filename = url_or_path_join(self._base_path, url_or_filename) - return cached_path(url_or_filename, download_config=download_config) + out = cached_path(url_or_filename, download_config=download_config) + out = tracked_str(out) + out.set_origin(url_or_filename) + return out def iter_archive(self, path_or_buf: Union[str, io.BufferedReader]): """Iterate over files within an archive. 
@@ -526,8 +534,10 @@ def extract(self, path_or_paths, num_proc="deprecated"): # Extract downloads the file first if it is not already downloaded if download_config.download_desc is None: download_config.download_desc = "Downloading data" + + extract_func = partial(self._download, download_config=download_config) extracted_paths = map_nested( - partial(cached_path, download_config=download_config), + extract_func, path_or_paths, num_proc=download_config.num_proc, desc="Extracting data files", diff --git a/src/datasets/exceptions.py b/src/datasets/exceptions.py index a6a7aa1acf9..619f2a10117 100644 --- a/src/datasets/exceptions.py +++ b/src/datasets/exceptions.py @@ -1,5 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright 2023 The HuggingFace Authors. +from typing import Any, Dict, List, Optional, Union + +from huggingface_hub import HfFileSystem + +from . import config +from .table import CastError +from .utils.track import TrackedIterable, tracked_list, tracked_str class DatasetsError(Exception): @@ -25,3 +32,54 @@ class DatasetNotFoundError(FileNotFoundDatasetsError): - a missing dataset, or - a private/gated dataset and the user is not authenticated. 
""" + + +class DatasetBuildError(DatasetsError): + pass + + +class ManualDownloadError(DatasetBuildError): + pass + + +class FileFormatError(DatasetBuildError): + pass + + +class DatasetGenerationError(DatasetBuildError): + pass + + +class DatasetGenerationCastError(DatasetGenerationError): + @classmethod + def from_cast_error( + cls, + cast_error: CastError, + builder_name: str, + gen_kwargs: Dict[str, Any], + token: Optional[Union[bool, str]], + ) -> "DatasetGenerationCastError": + explanation_message = ( + f"\n\nAll the data files must have the same columns, but at some point {cast_error.details()}" + ) + formatted_tracked_gen_kwargs: List[str] = [] + for gen_kwarg in gen_kwargs.values(): + if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterable)): + continue + while isinstance(gen_kwarg, (tracked_list, TrackedIterable)) and gen_kwarg.last_item is not None: + gen_kwarg = gen_kwarg.last_item + if isinstance(gen_kwarg, tracked_str): + gen_kwarg = gen_kwarg.get_origin() + if isinstance(gen_kwarg, str) and gen_kwarg.startswith("hf://"): + resolved_path = HfFileSystem(endpoint=config.HF_ENDPOINT, token=token).resolve_path(gen_kwarg) + gen_kwarg = "hf://" + resolved_path.unresolve() + if "@" + resolved_path.revision in gen_kwarg: + gen_kwarg = ( + gen_kwarg.replace("@" + resolved_path.revision, "", 1) + + f" (at revision {resolved_path.revision})" + ) + formatted_tracked_gen_kwargs.append(str(gen_kwarg)) + if formatted_tracked_gen_kwargs: + explanation_message += f"\n\nThis happened while the {builder_name} dataset builder was generating data using\n\n{', '.join(formatted_tracked_gen_kwargs)}" + help_message = "\n\nPlease either edit the data files to have matching columns, or separate them into different configurations (see docs at https://hf.co/docs/hub/datasets-manual-configuration#multiple-configurations)" + return cls("An error occurred while generating the dataset" + explanation_message + help_message) diff --git a/src/datasets/table.py 
b/src/datasets/table.py index 1a2df53d456..763716b6415 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -2216,6 +2216,25 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"): raise TypeError(f"Couldn't embed array of type\n{array.type}\nwith\n{feature}") +class CastError(ValueError): + """When it's not possible to cast an Arrow table to a specific schema or set of features""" + + def __init__(self, *args, table_column_names: List[str], requested_column_names: List[str]) -> None: + super().__init__(*args) + self.table_column_names = table_column_names + self.requested_column_names = requested_column_names + + def details(self): + new_columns = set(self.table_column_names) - set(self.requested_column_names) + missing_columns = set(self.requested_column_names) - set(self.table_column_names) + if new_columns and missing_columns: + return f"there are {len(new_columns)} new columns ({', '.join(new_columns)}) and {len(missing_columns)} missing columns ({', '.join(missing_columns)})." + elif new_columns: + return f"there are {len(new_columns)} new columns ({', '.join(new_columns)})" + else: + return f"there are {len(missing_columns)} missing columns ({', '.join(missing_columns)})" + + def cast_table_to_features(table: pa.Table, features: "Features"): """Cast a table to the arrow schema that corresponds to the requested features. 
@@ -2229,7 +2248,11 @@ def cast_table_to_features(table: pa.Table, features: "Features"): table (`pyarrow.Table`): the casted table """ if sorted(table.column_names) != sorted(features): - raise ValueError(f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match") + raise CastError( + f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match", + table_column_names=table.column_names, + requested_column_names=list(features), + ) arrays = [cast_array_to_feature(table[name], feature) for name, feature in features.items()] return pa.Table.from_arrays(arrays, schema=features.arrow_schema) @@ -2250,7 +2273,11 @@ def cast_table_to_schema(table: pa.Table, schema: pa.Schema): features = Features.from_arrow_schema(schema) if sorted(table.column_names) != sorted(features): - raise ValueError(f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match") + raise CastError( + f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match", + table_column_names=table.column_names, + requested_column_names=list(features), + ) arrays = [cast_array_to_feature(table[name], feature) for name, feature in features.items()] return pa.Table.from_arrays(arrays, schema=schema) diff --git a/src/datasets/utils/track.py b/src/datasets/utils/track.py new file mode 100644 index 00000000000..11a3787c7d8 --- /dev/null +++ b/src/datasets/utils/track.py @@ -0,0 +1,49 @@ +from collections.abc import Iterator +from typing import Iterable + + +class tracked_str(str): + origins = {} + + def set_origin(self, origin: str): + if super().__repr__() not in self.origins: + self.origins[super().__repr__()] = origin + + def get_origin(self): + return self.origins.get(super().__repr__(), str(self)) + + def __repr__(self) -> str: + if super().__repr__() not in self.origins or self.origins[super().__repr__()] == self: + return super().__repr__() + else: + return f"{str(self)} (origin={self.origins[super().__repr__()]})" + + 
+class tracked_list(list): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.last_item = None + + def __iter__(self) -> Iterator: + for x in super().__iter__(): + self.last_item = x + yield x + self.last_item = None + + def __repr__(self) -> str: + if self.last_item is None: + return super().__repr__() + else: + return f"{self.__class__.__name__}(current={self.last_item})" + + +class TrackedIterable(Iterable): + def __init__(self) -> None: + super().__init__() + self.last_item = None + + def __repr__(self) -> str: + if self.last_item is None: + super().__repr__() + else: + return f"{self.__class__.__name__}(current={self.last_item})"