Skip to content

Commit 9257c6e

Browse files
authored
Move checks from _map_single to map (#2660)
* Map refactor * Style * Map refactor * Style * Update type hints * Return fn_kwargs check to _map_single * Update type hints in code * Fix removed change * Update signature of DatasetDict.map
1 parent 78b55b7 commit 9257c6e

File tree

2 files changed

+20
-33
lines changed

2 files changed

+20
-33
lines changed

src/datasets/arrow_dataset.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,7 +1572,7 @@ def map(
15721572
batched: bool = False,
15731573
batch_size: Optional[int] = 1000,
15741574
drop_last_batch: bool = False,
1575-
remove_columns: Optional[List[str]] = None,
1575+
remove_columns: Optional[Union[str, List[str]]] = None,
15761576
keep_in_memory: bool = False,
15771577
load_from_cache_file: bool = None,
15781578
cache_file_name: Optional[str] = None,
@@ -1606,7 +1606,7 @@ def map(
16061606
`batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function`.
16071607
drop_last_batch (:obj:`bool`, default `False`): Whether a last batch smaller than the batch_size should be
16081608
dropped instead of being processed by the function.
1609-
remove_columns (`Optional[List[str]]`, default `None`): Remove a selection of columns while doing the mapping.
1609+
remove_columns (`Optional[Union[str, List[str]]]`, default `None`): Remove a selection of columns while doing the mapping.
16101610
Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
16111611
columns with names in `remove_columns`, these columns will be kept.
16121612
keep_in_memory (:obj:`bool`, default `False`): Keep the dataset in memory instead of writing it to a cache file.
@@ -1631,6 +1631,9 @@ def map(
16311631
If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
16321632
desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside with the progress bar while mapping examples.
16331633
"""
1634+
assert (
1635+
not keep_in_memory or cache_file_name is None
1636+
), "Please use either `keep_in_memory` or `cache_file_name` but not both."
16341637
assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."
16351638

16361639
# If the array is empty we do nothing
@@ -1652,6 +1655,17 @@ def map(
16521655
)
16531656
)
16541657

1658+
if isinstance(remove_columns, str):
1659+
remove_columns = [remove_columns]
1660+
1661+
if remove_columns is not None and any(col not in self._data.column_names for col in remove_columns):
1662+
raise ValueError(
1663+
"Column to remove {} not in the dataset. Current columns in the dataset: {}".format(
1664+
list(filter(lambda col: col not in self._data.column_names, remove_columns)),
1665+
self._data.column_names,
1666+
)
1667+
)
1668+
16551669
load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()
16561670

16571671
if fn_kwargs is None:
@@ -1758,7 +1772,7 @@ def _map_single(
17581772
self,
17591773
function: Optional[Callable] = None,
17601774
with_indices: bool = False,
1761-
input_columns: Optional[Union[str, List[str]]] = None,
1775+
input_columns: Optional[List[str]] = None,
17621776
batched: bool = False,
17631777
batch_size: Optional[int] = 1000,
17641778
drop_last_batch: bool = False,
@@ -1787,7 +1801,7 @@ def _map_single(
17871801
- `function(batch: Union[Dict[List], List[Any]], indices: List[int]) -> Union[Dict, Any]` if `batched=True` and `with_indices=True`
17881802
If no function is provided, default to identity function: lambda x: x
17891803
with_indices (:obj:`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`.
1790-
input_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): The columns to be passed into `function` as
1804+
input_columns (`Optional[List[str]]`, defaults to `None`): The columns to be passed into `function` as
17911805
positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument.
17921806
batched (:obj:`bool`, defaults to `False`): Provide batch of examples to `function`
17931807
batch_size (`Optional[int]`, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True`
@@ -1816,10 +1830,6 @@ def _map_single(
18161830
disable_tqdm (:obj:`bool`, defaults to `False`): Whether to silence tqdm's output.
18171831
desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside with the progress bar while mapping examples.
18181832
"""
1819-
assert (
1820-
not keep_in_memory or cache_file_name is None
1821-
), "Please use either `keep_in_memory` or `cache_file_name` but not both."
1822-
18231833
# Reduce logging to keep things readable in multiprocessing with tqdm
18241834
if rank is not None and logging.get_verbosity() < logging.WARNING:
18251835
logging.set_verbosity_warning()
@@ -1828,29 +1838,6 @@ def _map_single(
18281838
if rank is not None and not disable_tqdm and "notebook" in tqdm.__name__:
18291839
print(" ", end="", flush=True)
18301840

1831-
# Select the columns (arrow columns) to process
1832-
if remove_columns is not None and any(col not in self._data.column_names for col in remove_columns):
1833-
raise ValueError(
1834-
"Column to remove {} not in the dataset. Current columns in the dataset: {}".format(
1835-
list(filter(lambda col: col not in self._data.column_names, remove_columns)),
1836-
self._data.column_names,
1837-
)
1838-
)
1839-
1840-
load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()
1841-
1842-
if isinstance(input_columns, str):
1843-
input_columns = [input_columns]
1844-
1845-
if input_columns is not None:
1846-
for input_column in input_columns:
1847-
if input_column not in self._data.column_names:
1848-
raise ValueError(
1849-
"Input column {} not in the dataset. Current columns in the dataset: {}".format(
1850-
input_column, self._data.column_names
1851-
)
1852-
)
1853-
18541841
if fn_kwargs is None:
18551842
fn_kwargs = {}
18561843

src/datasets/dataset_dict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def map(
417417
input_columns: Optional[Union[str, List[str]]] = None,
418418
batched: bool = False,
419419
batch_size: Optional[int] = 1000,
420-
remove_columns: Optional[List[str]] = None,
420+
remove_columns: Optional[Union[str, List[str]]] = None,
421421
keep_in_memory: bool = False,
422422
load_from_cache_file: bool = True,
423423
cache_file_names: Optional[Dict[str, Optional[str]]] = None,
@@ -444,7 +444,7 @@ def map(
444444
batched (`bool`, defaults to `False`): Provide batch of examples to `function`
445445
batch_size (`Optional[int]`, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True`
446446
`batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function`
447-
remove_columns (`Optional[List[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
447+
remove_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
448448
Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
449449
columns with names in `remove_columns`, these columns will be kept.
450450
keep_in_memory (`bool`, defaults to `False`): Keep the dataset in memory instead of writing it to a cache file.

0 commit comments

Comments
 (0)