Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 18 additions & 31 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,7 +1572,7 @@ def map(
batched: bool = False,
batch_size: Optional[int] = 1000,
drop_last_batch: bool = False,
remove_columns: Optional[List[str]] = None,
remove_columns: Optional[Union[str, List[str]]] = None,
keep_in_memory: bool = False,
load_from_cache_file: bool = None,
cache_file_name: Optional[str] = None,
Expand Down Expand Up @@ -1606,7 +1606,7 @@ def map(
`batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function`.
drop_last_batch (:obj:`bool`, default `False`): Whether a last batch smaller than the batch_size should be
dropped instead of being processed by the function.
remove_columns (`Optional[List[str]]`, default `None`): Remove a selection of columns while doing the mapping.
remove_columns (`Optional[Union[str, List[str]]]`, default `None`): Remove a selection of columns while doing the mapping.
Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
columns with names in `remove_columns`, these columns will be kept.
keep_in_memory (:obj:`bool`, default `False`): Keep the dataset in memory instead of writing it to a cache file.
Expand All @@ -1631,6 +1631,9 @@ def map(
If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside with the progress bar while mapping examples.
"""
assert (
not keep_in_memory or cache_file_name is None
), "Please use either `keep_in_memory` or `cache_file_name` but not both."
assert num_proc is None or num_proc > 0, "num_proc must be an integer > 0."

# If the array is empty we do nothing
Expand All @@ -1652,6 +1655,17 @@ def map(
)
)

if isinstance(remove_columns, str):
remove_columns = [remove_columns]

if remove_columns is not None and any(col not in self._data.column_names for col in remove_columns):
raise ValueError(
"Column to remove {} not in the dataset. Current columns in the dataset: {}".format(
list(filter(lambda col: col not in self._data.column_names, remove_columns)),
self._data.column_names,
)
)

load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()

if fn_kwargs is None:
Expand Down Expand Up @@ -1758,7 +1772,7 @@ def _map_single(
self,
function: Optional[Callable] = None,
with_indices: bool = False,
input_columns: Optional[Union[str, List[str]]] = None,
input_columns: Optional[List[str]] = None,
batched: bool = False,
batch_size: Optional[int] = 1000,
drop_last_batch: bool = False,
Expand Down Expand Up @@ -1787,7 +1801,7 @@ def _map_single(
- `function(batch: Union[Dict[List], List[Any]], indices: List[int]) -> Union[Dict, Any]` if `batched=True` and `with_indices=True`
If no function is provided, default to identity function: lambda x: x
with_indices (:obj:`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`.
input_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): The columns to be passed into `function` as
input_columns (`Optional[List[str]]`, defaults to `None`): The columns to be passed into `function` as
positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument.
batched (:obj:`bool`, defaults to `False`): Provide batch of examples to `function`
batch_size (`Optional[int]`, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True`
Expand Down Expand Up @@ -1816,10 +1830,6 @@ def _map_single(
disable_tqdm (:obj:`bool`, defaults to `False`): Whether to silence tqdm's output.
desc (`Optional[str]`, defaults to `None`): Meaningful description to be displayed alongside with the progress bar while mapping examples.
"""
assert (
not keep_in_memory or cache_file_name is None
), "Please use either `keep_in_memory` or `cache_file_name` but not both."

# Reduce logging to keep things readable in multiprocessing with tqdm
if rank is not None and logging.get_verbosity() < logging.WARNING:
logging.set_verbosity_warning()
Expand All @@ -1828,29 +1838,6 @@ def _map_single(
if rank is not None and not disable_tqdm and "notebook" in tqdm.__name__:
print(" ", end="", flush=True)

# Select the columns (arrow columns) to process
if remove_columns is not None and any(col not in self._data.column_names for col in remove_columns):
raise ValueError(
"Column to remove {} not in the dataset. Current columns in the dataset: {}".format(
list(filter(lambda col: col not in self._data.column_names, remove_columns)),
self._data.column_names,
)
)

load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()

if isinstance(input_columns, str):
input_columns = [input_columns]

if input_columns is not None:
for input_column in input_columns:
if input_column not in self._data.column_names:
raise ValueError(
"Input column {} not in the dataset. Current columns in the dataset: {}".format(
input_column, self._data.column_names
)
)

if fn_kwargs is None:
fn_kwargs = {}

Expand Down
4 changes: 2 additions & 2 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def map(
input_columns: Optional[Union[str, List[str]]] = None,
batched: bool = False,
batch_size: Optional[int] = 1000,
remove_columns: Optional[List[str]] = None,
remove_columns: Optional[Union[str, List[str]]] = None,
keep_in_memory: bool = False,
load_from_cache_file: bool = True,
cache_file_names: Optional[Dict[str, Optional[str]]] = None,
Expand All @@ -444,7 +444,7 @@ def map(
batched (`bool`, defaults to `False`): Provide batch of examples to `function`
batch_size (`Optional[int]`, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True`
`batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function`
remove_columns (`Optional[List[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
remove_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
columns with names in `remove_columns`, these columns will be kept.
keep_in_memory (`bool`, defaults to `False`): Keep the dataset in memory instead of writing it to a cache file.
Expand Down