
Commit 7733241

Better cast error when generating dataset (#6509)
* better cast error
* minor
* move to exceptions.py and add help message
1 parent 3329be8 commit 7733241

File tree

5 files changed: +166 -24 lines

src/datasets/builder.py

Lines changed: 15 additions & 17 deletions
@@ -52,6 +52,7 @@
 from .download.download_manager import DownloadManager, DownloadMode
 from .download.mock_download_manager import MockDownloadManager
 from .download.streaming_download_manager import StreamingDownloadManager, xopen
+from .exceptions import DatasetGenerationCastError, DatasetGenerationError, FileFormatError, ManualDownloadError
 from .features import Features
 from .filesystems import (
     is_remote_filesystem,
@@ -64,6 +65,7 @@
 from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
 from .splits import Split, SplitDict, SplitGenerator, SplitInfo
 from .streaming import extend_dataset_builder_for_streaming
+from .table import CastError
 from .utils import logging
 from .utils import tqdm as hf_tqdm
 from .utils._filelock import FileLock
@@ -80,6 +82,7 @@
     temporary_assignment,
 )
 from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs
+from .utils.track import tracked_list


 logger = logging.get_logger(__name__)
@@ -89,22 +92,6 @@ class InvalidConfigName(ValueError):
     pass


-class DatasetBuildError(Exception):
-    pass
-
-
-class ManualDownloadError(DatasetBuildError):
-    pass
-
-
-class DatasetGenerationError(DatasetBuildError):
-    pass
-
-
-class FileFormatError(DatasetBuildError):
-    pass
-
-
 @dataclass
 class BuilderConfig:
     """Base class for `DatasetBuilder` data configuration.
@@ -1895,6 +1882,7 @@ def _rename_shard(shard_id_and_job: Tuple[int]):
     def _prepare_split_single(
         self, gen_kwargs: dict, fpath: str, file_format: str, max_shard_size: int, job_id: int
     ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
+        gen_kwargs = {k: tracked_list(v) if isinstance(v, list) else v for k, v in gen_kwargs.items()}
         generator = self._generate_tables(**gen_kwargs)
         writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
         embed_local_files = file_format == "parquet"
@@ -1928,7 +1916,15 @@
                         storage_options=self._fs.storage_options,
                         embed_local_files=embed_local_files,
                     )
-                writer.write_table(table)
+                try:
+                    writer.write_table(table)
+                except CastError as cast_error:
+                    raise DatasetGenerationCastError.from_cast_error(
+                        cast_error=cast_error,
+                        builder_name=self.info.builder_name,
+                        gen_kwargs=gen_kwargs,
+                        token=self.token,
+                    )
                 num_examples_progress_update += len(table)
                 if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
                     _time = time.time()
@@ -1946,6 +1942,8 @@
             # Ignore the writer's error for no examples written to the file if this error was caused by the error in _generate_examples before the first example was yielded
             if isinstance(e, SchemaInferenceError) and e.__context__ is not None:
                 e = e.__context__
+            if isinstance(e, DatasetGenerationError):
+                raise
            raise DatasetGenerationError("An error occurred while generating the dataset") from e

         yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)
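Taken together, these changes mean a column mismatch now reaches the caller as a `DatasetGenerationCastError` (preserved by the `isinstance` re-raise above) instead of a generic `DatasetGenerationError`. A minimal sketch of catching it; the repo id `user/mixed-columns` is hypothetical, standing in for any dataset whose data files disagree on columns:

    from datasets import load_dataset
    from datasets.exceptions import DatasetGenerationCastError

    try:
        ds = load_dataset("user/mixed-columns")  # hypothetical dataset with inconsistent columns
    except DatasetGenerationCastError as err:
        # The message now names the mismatched columns, the builder,
        # and the data file being written when the cast failed.
        print(err)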

src/datasets/download/download_manager.py

Lines changed: 15 additions & 5 deletions
@@ -25,7 +25,7 @@
 from datetime import datetime
 from functools import partial
 from itertools import chain
-from typing import Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Generator, List, Optional, Tuple, Union

 from .. import config
 from ..utils import tqdm as hf_tqdm
@@ -34,6 +34,7 @@
 from ..utils.info_utils import get_size_checksum_dict
 from ..utils.logging import get_logger
 from ..utils.py_utils import NestedDataStructure, map_nested, size_str
+from ..utils.track import TrackedIterable, tracked_str
 from .download_config import DownloadConfig


@@ -147,16 +148,20 @@ def _get_extraction_protocol(path: str) -> Optional[str]:
         return _get_extraction_protocol_with_magic_number(f)


-class _IterableFromGenerator(Iterable):
+class _IterableFromGenerator(TrackedIterable):
     """Utility class to create an iterable from a generator function, in order to reset the generator when needed."""

     def __init__(self, generator: Callable, *args, **kwargs):
+        super().__init__()
         self.generator = generator
         self.args = args
         self.kwargs = kwargs

     def __iter__(self):
-        yield from self.generator(*self.args, **self.kwargs)
+        for x in self.generator(*self.args, **self.kwargs):
+            self.last_item = x
+            yield x
+        self.last_item = None


 class ArchiveIterable(_IterableFromGenerator):
@@ -443,7 +448,10 @@ def _download(self, url_or_filename: str, download_config: DownloadConfig) -> str:
         if is_relative_path(url_or_filename):
             # append the relative path to the base_path
             url_or_filename = url_or_path_join(self._base_path, url_or_filename)
-        return cached_path(url_or_filename, download_config=download_config)
+        out = cached_path(url_or_filename, download_config=download_config)
+        out = tracked_str(out)
+        out.set_origin(url_or_filename)
+        return out

     def iter_archive(self, path_or_buf: Union[str, io.BufferedReader]):
         """Iterate over files within an archive.
@@ -526,8 +534,10 @@ def extract(self, path_or_paths, num_proc="deprecated"):
         # Extract downloads the file first if it is not already downloaded
         if download_config.download_desc is None:
             download_config.download_desc = "Downloading data"
+
+        extract_func = partial(self._download, download_config=download_config)
         extracted_paths = map_nested(
-            partial(cached_path, download_config=download_config),
+            extract_func,
             path_or_paths,
             num_proc=download_config.num_proc,
             desc="Extracting data files",
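With `_download` returning a `tracked_str`, an error report can recover the original URL behind a resolved cache path. A small illustration of the origin-tracking behavior; both paths below are hypothetical:

    from datasets.utils.track import tracked_str

    out = tracked_str("/cache/ab12cd")  # stand-in for a cached_path() result
    out.set_origin("hf://datasets/user/repo/data.csv")  # the URL it was downloaded from
    print(repr(out))         # /cache/ab12cd (origin=hf://datasets/user/repo/data.csv)
    print(out.get_origin())  # hf://datasets/user/repo/data.csv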

src/datasets/exceptions.py

Lines changed: 58 additions & 0 deletions
@@ -1,5 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # Copyright 2023 The HuggingFace Authors.
+from typing import Any, Dict, List, Optional, Union
+
+from huggingface_hub import HfFileSystem
+
+from . import config
+from .table import CastError
+from .utils.track import TrackedIterable, tracked_list, tracked_str


 class DatasetsError(Exception):
@@ -25,3 +32,54 @@ class DatasetNotFoundError(FileNotFoundDatasetsError):
     - a missing dataset, or
     - a private/gated dataset and the user is not authenticated.
     """
+
+
+class DatasetBuildError(DatasetsError):
+    pass
+
+
+class ManualDownloadError(DatasetBuildError):
+    pass
+
+
+class FileFormatError(DatasetBuildError):
+    pass
+
+
+class DatasetGenerationError(DatasetBuildError):
+    pass
+
+
+class DatasetGenerationCastError(DatasetGenerationError):
+    @classmethod
+    def from_cast_error(
+        cls,
+        cast_error: CastError,
+        builder_name: str,
+        gen_kwargs: Dict[str, Any],
+        token: Optional[Union[bool, str]],
+    ) -> "DatasetGenerationCastError":
+        explanation_message = (
+            f"\n\nAll the data files must have the same columns, but at some point {cast_error.details()}"
+        )
+        formatted_tracked_gen_kwargs: List[str] = []
+        for gen_kwarg in gen_kwargs.values():
+            if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterable)):
+                continue
+            while isinstance(gen_kwarg, (tracked_list, TrackedIterable)) and gen_kwarg.last_item is not None:
+                gen_kwarg = gen_kwarg.last_item
+            if isinstance(gen_kwarg, tracked_str):
+                gen_kwarg = gen_kwarg.get_origin()
+            if isinstance(gen_kwarg, str) and gen_kwarg.startswith("hf://"):
+                resolved_path = HfFileSystem(endpoint=config.HF_ENDPOINT, token=token).resolve_path(gen_kwarg)
+                gen_kwarg = "hf://" + resolved_path.unresolve()
+                if "@" + resolved_path.revision in gen_kwarg:
+                    gen_kwarg = (
+                        gen_kwarg.replace("@" + resolved_path.revision, "", 1)
+                        + f" (at revision {resolved_path.revision})"
+                    )
+            formatted_tracked_gen_kwargs.append(str(gen_kwarg))
+        if formatted_tracked_gen_kwargs:
+            explanation_message += f"\n\nThis happened while the {builder_name} dataset builder was generating data using\n\n{', '.join(formatted_tracked_gen_kwargs)}"
+        help_message = "\n\nPlease either edit the data files to have matching columns, or separate them into different configurations (see docs at https://hf.co/docs/hub/datasets-manual-configuration#multiple-configurations)"
+        return cls("An error occurred while generating the dataset" + explanation_message + help_message)
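`from_cast_error` drills into the tracked `gen_kwargs` to find the last item that was being processed, swaps `hf://` cache paths back to readable Hub URLs (moving any `@revision` into a suffix), and assembles the final message. A sketch of invoking it directly; all values are hypothetical, and a local file name is used so no Hub resolution is triggered:

    from datasets.exceptions import DatasetGenerationCastError
    from datasets.table import CastError
    from datasets.utils.track import tracked_list, tracked_str

    cast_error = CastError(
        "Couldn't cast",
        table_column_names=["text", "label", "extra"],
        requested_column_names=["text", "label"],
    )
    files = tracked_list([tracked_str("train.csv")])  # hypothetical data file
    next(iter(files))  # simulate iteration so last_item points at train.csv
    err = DatasetGenerationCastError.from_cast_error(
        cast_error=cast_error,
        builder_name="csv",
        gen_kwargs={"files": files},
        token=None,
    )
    print(err)  # names the extra column, the csv builder, and train.csv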

src/datasets/table.py

Lines changed: 29 additions & 2 deletions
@@ -2216,6 +2216,25 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"):
     raise TypeError(f"Couldn't embed array of type\n{array.type}\nwith\n{feature}")


+class CastError(ValueError):
+    """When it's not possible to cast an Arrow table to a specific schema or set of features"""
+
+    def __init__(self, *args, table_column_names: List[str], requested_column_names: List[str]) -> None:
+        super().__init__(*args)
+        self.table_column_names = table_column_names
+        self.requested_column_names = requested_column_names
+
+    def details(self):
+        new_columns = set(self.table_column_names) - set(self.requested_column_names)
+        missing_columns = set(self.requested_column_names) - set(self.table_column_names)
+        if new_columns and missing_columns:
+            return f"there are {len(new_columns)} new columns ({', '.join(new_columns)}) and {len(missing_columns)} missing columns ({', '.join(missing_columns)})."
+        elif new_columns:
+            return f"there are {len(new_columns)} new columns ({new_columns})"
+        else:
+            return f"there are {len(missing_columns)} missing columns ({missing_columns})"
+
+
 def cast_table_to_features(table: pa.Table, features: "Features"):
     """Cast a table to the arrow schema that corresponds to the requested features.
@@ -2229,7 +2248,11 @@ def cast_table_to_features(table: pa.Table, features: "Features"):
         table (`pyarrow.Table`): the casted table
     """
     if sorted(table.column_names) != sorted(features):
-        raise ValueError(f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match")
+        raise CastError(
+            f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match",
+            table_column_names=table.column_names,
+            requested_column_names=list(features),
+        )
     arrays = [cast_array_to_feature(table[name], feature) for name, feature in features.items()]
     return pa.Table.from_arrays(arrays, schema=features.arrow_schema)

@@ -2250,7 +2273,11 @@ def cast_table_to_schema(table: pa.Table, schema: pa.Schema):

     features = Features.from_arrow_schema(schema)
     if sorted(table.column_names) != sorted(features):
-        raise ValueError(f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match")
+        raise CastError(
+            f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match",
+            table_column_names=table.column_names,
+            requested_column_names=list(features),
+        )
     arrays = [cast_array_to_feature(table[name], feature) for name, feature in features.items()]
     return pa.Table.from_arrays(arrays, schema=schema)
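Because `CastError` carries both column lists, `details()` can state exactly which columns are unexpected and which are absent. A quick example of the two-sided case, with made-up column names:

    from datasets.table import CastError

    err = CastError(
        "Couldn't cast",  # the underlying cast failure message
        table_column_names=["id", "text", "score"],
        requested_column_names=["id", "text", "label"],
    )
    print(err.details())
    # there are 1 new columns (score) and 1 missing columns (label).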

src/datasets/utils/track.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from collections.abc import Iterator
+from typing import Iterable
+
+
+class tracked_str(str):
+    origins = {}
+
+    def set_origin(self, origin: str):
+        if super().__repr__() not in self.origins:
+            self.origins[super().__repr__()] = origin
+
+    def get_origin(self):
+        return self.origins.get(super().__repr__(), str(self))
+
+    def __repr__(self) -> str:
+        if super().__repr__() not in self.origins or self.origins[super().__repr__()] == self:
+            return super().__repr__()
+        else:
+            return f"{str(self)} (origin={self.origins[super().__repr__()]})"
+
+
+class tracked_list(list):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.last_item = None
+
+    def __iter__(self) -> Iterator:
+        for x in super().__iter__():
+            self.last_item = x
+            yield x
+        self.last_item = None
+
+    def __repr__(self) -> str:
+        if self.last_item is None:
+            return super().__repr__()
+        else:
+            return f"{self.__class__.__name__}(current={self.last_item})"
+
+
+class TrackedIterable(Iterable):
+    def __init__(self) -> None:
+        super().__init__()
+        self.last_item = None
+
+    def __repr__(self) -> str:
+        if self.last_item is None:
+            return super().__repr__()
+        else:
+            return f"{self.__class__.__name__}(current={self.last_item})"
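These helpers are deliberately lightweight: `tracked_list` and `TrackedIterable` just remember the last item yielded during iteration, which is what lets the cast error point at the exact file in flight. A short demonstration:

    from datasets.utils.track import tracked_list

    files = tracked_list(["a.csv", "b.csv", "c.csv"])
    for f in files:
        if f == "b.csv":
            break  # stop mid-iteration, as a raised error would

    print(repr(files))  # tracked_list(current=b.csv)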
