Skip to content
Merged
52 changes: 18 additions & 34 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,7 +1568,7 @@ def map(
batched: bool = False,
batch_size: Optional[int] = 1000,
drop_last_batch: bool = False,
remove_columns: Optional[List[str]] = None,
remove_columns: Optional[Union[str, List[str]]] = None,
keep_in_memory: bool = False,
load_from_cache_file: bool = None,
cache_file_name: Optional[str] = None,
Expand Down Expand Up @@ -1602,7 +1602,7 @@ def map(
`batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function`.
drop_last_batch (:obj:`bool`, default `False`): Whether a last batch smaller than the batch_size should be
dropped instead of being processed by the function.
remove_columns (`Optional[List[str]]`, default `None`): Remove a selection of columns while doing the mapping.
remove_columns (`Optional[Union[str, List[str]]]`, default `None`): Remove a selection of columns while doing the mapping.
Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
columns with names in `remove_columns`, these columns will be kept.
keep_in_memory (:obj:`bool`, default `False`): Keep the dataset in memory instead of writing it to a cache file.
Expand All @@ -1627,6 +1627,9 @@ def map(
If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside the progress bar while mapping examples.
"""
assert (
not keep_in_memory or cache_file_name is None
), "Please use either `keep_in_memory` or `cache_file_name` but not both."
assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."

# If the array is empty we do nothing
Expand All @@ -1648,6 +1651,17 @@ def map(
)
)

if isinstance(remove_columns, str):
remove_columns = [remove_columns]

if remove_columns is not None and any(col not in self._data.column_names for col in remove_columns):
raise ValueError(
"Column to remove {} not in the dataset. Current columns in the dataset: {}".format(
list(filter(lambda col: col not in self._data.column_names, remove_columns)),
self._data.column_names,
)
)

load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()

if fn_kwargs is None:
Expand Down Expand Up @@ -1751,7 +1765,7 @@ def _map_single(
batched: bool = False,
batch_size: Optional[int] = 1000,
drop_last_batch: bool = False,
remove_columns: Optional[List[str]] = None,
remove_columns: Optional[Union[str, List[str]]] = None,
keep_in_memory: bool = False,
load_from_cache_file: bool = None,
cache_file_name: Optional[str] = None,
Expand Down Expand Up @@ -1782,7 +1796,7 @@ def _map_single(
`batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function`
drop_last_batch (:obj:`bool`, default: `False`): Whether a last batch smaller than the batch_size should be
dropped instead of being processed by the function.
remove_columns (`Optional[List[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
remove_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
columns with names in `remove_columns`, these columns will be kept.
keep_in_memory (:obj:`bool`, defaults to `False`): Keep the dataset in memory instead of writing it to a cache file.
Expand All @@ -1803,10 +1817,6 @@ def _map_single(
offset (:obj:`int`, defaults to 0): If specified, this is an offset applied to the indices passed to `function` if `with_indices=True`
desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside the progress bar while mapping examples.
"""
assert (
not keep_in_memory or cache_file_name is None
), "Please use either `keep_in_memory` or `cache_file_name` but not both."

# Reduce logging to keep things readable in multiprocessing with tqdm
if rank is not None and logging.get_verbosity() < logging.WARNING:
logging.set_verbosity_warning()
Expand All @@ -1815,32 +1825,6 @@ def _map_single(
if rank is not None and "notebook" in tqdm.__name__:
print(" ", end="", flush=True)

# Select the columns (arrow columns) to process
if remove_columns is not None and any(col not in self._data.column_names for col in remove_columns):
raise ValueError(
"Column to remove {} not in the dataset. Current columns in the dataset: {}".format(
list(filter(lambda col: col not in self._data.column_names, remove_columns)),
self._data.column_names,
)
)

load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()

if isinstance(input_columns, str):
input_columns = [input_columns]

if input_columns is not None:
for input_column in input_columns:
if input_column not in self._data.column_names:
raise ValueError(
"Input column {} not in the dataset. Current columns in the dataset: {}".format(
input_column, self._data.column_names
)
)

if fn_kwargs is None:
fn_kwargs = {}

# If we do batch computation but no batch size is provided, default to the full dataset
if batched and (batch_size is None or batch_size <= 0):
batch_size = self.num_rows
Expand Down