Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, Iterable, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union

import fsspec
import pyarrow as pa
Expand Down Expand Up @@ -64,6 +64,7 @@
from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
from .splits import Split, SplitDict, SplitGenerator, SplitInfo
from .streaming import extend_dataset_builder_for_streaming
from .table import CastError
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils._filelock import FileLock
Expand All @@ -80,6 +81,7 @@
temporary_assignment,
)
from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs
from .utils.track import TrackedIterable, tracked_list, tracked_str


logger = logging.get_logger(__name__)
Expand All @@ -105,6 +107,14 @@ class FileFormatError(DatasetBuildError):
pass


class CastErrorDuringDatasetGeneration(CastError):
    """`CastError` raised while a dataset is being generated.

    In addition to the column-name mismatch information carried by `CastError`,
    it keeps the `gen_kwargs` of the generation job that failed, so the error
    message can point at the data file being processed.
    """

    def __init__(
        self, *args, table_column_names: List[str], requested_column_names: List[str], gen_kwargs: Dict[str, Any]
    ) -> None:
        # Keep the failing job's generation kwargs for later error reporting.
        self.gen_kwargs = gen_kwargs
        super().__init__(*args, table_column_names=table_column_names, requested_column_names=requested_column_names)


@dataclass
class BuilderConfig:
"""Base class for `DatasetBuilder` data configuration.
Expand Down Expand Up @@ -1895,6 +1905,7 @@ def _rename_shard(shard_id_and_job: Tuple[int]):
def _prepare_split_single(
self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int
) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
gen_kwargs = {k: tracked_list(v) if isinstance(v, list) else v for k, v in gen_kwargs.items()}
generator = self._generate_tables(**gen_kwargs)
writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
embed_local_files = file_format == "parquet"
Expand Down Expand Up @@ -1928,7 +1939,15 @@ def _prepare_split_single(
storage_options=self._fs.storage_options,
embed_local_files=embed_local_files,
)
writer.write_table(table)
try:
writer.write_table(table)
except CastError as cast_error:
raise CastErrorDuringDatasetGeneration(
str(cast_error),
table_column_names=cast_error.table_column_names,
requested_column_names=cast_error.requested_column_names,
gen_kwargs=gen_kwargs,
) from None
num_examples_progress_update += len(table)
if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
_time = time.time()
Expand All @@ -1946,7 +1965,23 @@ def _prepare_split_single(
# Ignore the writer's error for no examples written to the file if this error was caused by the error in _generate_examples before the first example was yielded
if isinstance(e, SchemaInferenceError) and e.__context__ is not None:
e = e.__context__
raise DatasetGenerationError("An error occurred while generating the dataset") from e
if isinstance(e, CastErrorDuringDatasetGeneration):
additional_error = (
f"\n\nAll the data files must have the same columns, but at some point {e.details()}"
)
tracked_gen_kwargs = [
v for v in e.gen_kwargs.values() if isinstance(v, (tracked_str, tracked_list, TrackedIterable))
]
formatted_tracked_gen_kwargs: List[str] = []
for gen_kwarg in tracked_gen_kwargs:
while isinstance(gen_kwarg, (tracked_list, TrackedIterable)) and gen_kwarg.last_item is not None:
gen_kwarg = gen_kwarg.last_item
formatted_tracked_gen_kwargs.append(repr(gen_kwarg))
if tracked_gen_kwargs:
additional_error += f"\n\nThis happened while the {self.__class__.__name__} dataset builder was generating data using\n\n{', '.join(formatted_tracked_gen_kwargs)}"
else:
additional_error = ""
raise DatasetGenerationError("An error occurred while generating the dataset" + additional_error) from e

yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)

Expand Down
20 changes: 15 additions & 5 deletions src/datasets/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from datetime import datetime
from functools import partial
from itertools import chain
from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union

from .. import config
from ..utils import tqdm as hf_tqdm
Expand All @@ -34,6 +34,7 @@
from ..utils.info_utils import get_size_checksum_dict
from ..utils.logging import get_logger
from ..utils.py_utils import NestedDataStructure, map_nested, size_str
from ..utils.track import TrackedIterable, tracked_str
from .download_config import DownloadConfig


Expand Down Expand Up @@ -147,16 +148,20 @@ def _get_extraction_protocol(path: str) -> Optional[str]:
return _get_extraction_protocol_with_magic_number(f)


class _IterableFromGenerator(Iterable):
class _IterableFromGenerator(TrackedIterable):
    """Utility class to create an iterable from a generator function, in order to reset the generator when needed."""

    def __init__(self, generator: Callable, *args, **kwargs):
        super().__init__()
        self.generator = generator
        self.args = args
        self.kwargs = kwargs

    def __iter__(self):
        for item in self.generator(*self.args, **self.kwargs):
            # Remember the item currently being yielded so error messages can
            # report where iteration was when a failure happened.
            self.last_item = item
            yield item
        # Iteration finished cleanly: nothing is "in progress" anymore.
        self.last_item = None


class ArchiveIterable(_IterableFromGenerator):
Expand Down Expand Up @@ -443,7 +448,10 @@ def _download(self, url_or_filename: str, download_config: DownloadConfig) -> st
if is_relative_path(url_or_filename):
# append the relative path to the base_path
url_or_filename = url_or_path_join(self._base_path, url_or_filename)
return cached_path(url_or_filename, download_config=download_config)
out = cached_path(url_or_filename, download_config=download_config)
out = tracked_str(out)
out.set_origin(url_or_filename)
return out

def iter_archive(self, path_or_buf: Union[str, io.BufferedReader]):
"""Iterate over files within an archive.
Expand Down Expand Up @@ -526,8 +534,10 @@ def extract(self, path_or_paths, num_proc="deprecated"):
# Extract downloads the file first if it is not already downloaded
if download_config.download_desc is None:
download_config.download_desc = "Downloading data"

extract_func = partial(self._download, download_config=download_config)
extracted_paths = map_nested(
partial(cached_path, download_config=download_config),
extract_func,
path_or_paths,
num_proc=download_config.num_proc,
desc="Extracting data files",
Expand Down
31 changes: 29 additions & 2 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,25 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"):
raise TypeError(f"Couldn't embed array of type\n{array.type}\nwith\n{feature}")


class CastError(ValueError):
    """When it's not possible to cast an Arrow table to a specific schema or set of features"""

    def __init__(self, *args, table_column_names: List[str], requested_column_names: List[str]) -> None:
        super().__init__(*args)
        # Columns actually present in the table vs. columns the schema/features requested.
        self.table_column_names = table_column_names
        self.requested_column_names = requested_column_names

    def details(self) -> str:
        """Return a human-readable description of the column mismatch.

        Fix vs. the original: the single-mismatch branches interpolated the raw
        `set` repr (braces, quotes, nondeterministic order) instead of joining
        the names like the combined branch does. All branches now join sorted
        names, so messages are consistent and deterministic.
        """
        new_columns = set(self.table_column_names) - set(self.requested_column_names)
        missing_columns = set(self.requested_column_names) - set(self.table_column_names)
        new_names = ", ".join(sorted(new_columns))
        missing_names = ", ".join(sorted(missing_columns))
        if new_columns and missing_columns:
            return (
                f"there are {len(new_columns)} new columns ({new_names}) "
                f"and {len(missing_columns)} missing columns ({missing_names})."
            )
        elif new_columns:
            return f"there are {len(new_columns)} new columns ({new_names})"
        else:
            return f"there are {len(missing_columns)} missing columns ({missing_names})"


def cast_table_to_features(table: pa.Table, features: "Features"):
"""Cast a table to the arrow schema that corresponds to the requested features.

Expand All @@ -2229,7 +2248,11 @@ def cast_table_to_features(table: pa.Table, features: "Features"):
table (`pyarrow.Table`): the casted table
"""
if sorted(table.column_names) != sorted(features):
raise ValueError(f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match")
raise CastError(
f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match",
table_column_names=table.column_names,
requested_column_names=list(features),
)
arrays = [cast_array_to_feature(table[name], feature) for name, feature in features.items()]
return pa.Table.from_arrays(arrays, schema=features.arrow_schema)

Expand All @@ -2250,7 +2273,11 @@ def cast_table_to_schema(table: pa.Table, schema: pa.Schema):

features = Features.from_arrow_schema(schema)
if sorted(table.column_names) != sorted(features):
raise ValueError(f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match")
raise CastError(
f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match",
table_column_names=table.column_names,
requested_column_names=list(features),
)
arrays = [cast_array_to_feature(table[name], feature) for name, feature in features.items()]
return pa.Table.from_arrays(arrays, schema=schema)

Expand Down
46 changes: 46 additions & 0 deletions src/datasets/utils/track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from collections.abc import Iterator
from typing import Iterable


class tracked_str(str):
    """A `str` that can remember where its value came from (e.g. a cache path's source URL)."""

    # Class-level registry shared by all instances, keyed by the string's repr.
    origins = {}

    def set_origin(self, origin: str):
        """Register `origin` for this string; the first registered origin wins."""
        key = super().__repr__()
        if key not in self.origins:
            self.origins[key] = origin

    def __repr__(self) -> str:
        key = super().__repr__()
        if key in self.origins and self.origins[key] != self:
            # Show the origin alongside the value when one was registered.
            return f"{str(self)} (origin={self.origins[key]})"
        return key


class tracked_list(list):
    """A `list` that remembers the item currently being yielded while it is iterated."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Item most recently yielded by an in-progress iteration; None when idle.
        self.last_item = None

    def __iter__(self) -> Iterator:
        underlying = super().__iter__()
        for element in underlying:
            self.last_item = element
            yield element
        # Iteration completed: clear the tracking state.
        self.last_item = None

    def __repr__(self) -> str:
        current = self.last_item
        return super().__repr__() if current is None else f"{self.__class__.__name__}(current={current})"


class TrackedIterable(Iterable):
    """Base class for iterables that track the last item they yielded.

    Subclasses are expected to set `self.last_item` while iterating and reset it
    to `None` when iteration completes, so error messages can report where
    iteration stopped.
    """

    def __init__(self) -> None:
        super().__init__()
        # Item most recently yielded by an in-progress iteration; None when idle.
        self.last_item = None

    def __repr__(self) -> str:
        if self.last_item is None:
            # Bug fix: the original evaluated super().__repr__() without returning it,
            # so __repr__ returned None and repr() raised TypeError.
            return super().__repr__()
        else:
            return f"{self.__class__.__name__}(current={self.last_item})"