Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2851,6 +2851,7 @@ def map(
suffix_template: str = "_{rank:05d}_of_{num_proc:05d}",
new_fingerprint: Optional[str] = None,
desc: Optional[str] = None,
try_original_type: Optional[bool] = True,
) -> "Dataset":
"""
Apply a function to all the examples in the table (individually or in batches) and update the table.
Expand Down Expand Up @@ -2932,6 +2933,9 @@ def map(
If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.
desc (`str`, *optional*, defaults to `None`):
Meaningful description to be displayed alongside the progress bar while mapping examples.
try_original_type (`bool`, *optional*, defaults to `True`):
Try to keep the types of the original columns (e.g. int32 -> int32).
Set to `False` if you want to always infer new types.

Example:

Expand Down Expand Up @@ -3022,6 +3026,7 @@ def map(
"features": features,
"disable_nullable": disable_nullable,
"fn_kwargs": fn_kwargs,
"try_original_type": try_original_type,
}

if new_fingerprint is None:
Expand Down Expand Up @@ -3216,6 +3221,7 @@ def _map_single(
new_fingerprint: Optional[str] = None,
rank: Optional[int] = None,
offset: int = 0,
try_original_type: Optional[bool] = True,
) -> Iterable[tuple[int, bool, Union[int, "Dataset"]]]:
"""Apply a function to all the elements in the table (individually or in batches)
and update the table (if function does update examples).
Expand Down Expand Up @@ -3257,6 +3263,9 @@ def _map_single(
If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments
rank: (`int`, optional, defaults to `None`): If specified, this is the process rank when doing multiprocessing
offset: (`int`, defaults to 0): If specified, this is an offset applied to the indices passed to `function` if `with_indices=True`.
try_original_type: (`bool`, optional, defaults to `True`):
Try to keep the types of the original columns (e.g. int32 -> int32).
Set to `False` if you want to always infer new types.
"""
if fn_kwargs is None:
fn_kwargs = {}
Expand Down Expand Up @@ -3528,7 +3537,7 @@ def iter_outputs(shard_iterable):
):
writer.write_table(batch.to_arrow())
else:
writer.write_batch(batch)
writer.write_batch(batch, try_original_type=try_original_type)
num_examples_progress_update += num_examples_in_batch
if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
_time = time.time()
Expand Down
8 changes: 7 additions & 1 deletion src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,13 +584,15 @@ def write_batch(
self,
batch_examples: dict[str, list],
writer_batch_size: Optional[int] = None,
try_original_type: Optional[bool] = True,
):
"""Write a batch of Example to file.
Ignores the batch if it appears to be empty,
preventing a potential schema update of unknown types.

Args:
batch_examples: the batch of examples to add.
try_original_type: if `True`, pass the column's entry in `try_features` (when there is one) as `try_type` to `OptimizedTypedSequence`; if `False`, always pass `try_type=None`.
"""
if batch_examples and len(next(iter(batch_examples.values()))) == 0:
return
Expand All @@ -615,7 +617,11 @@ def write_batch(
arrays.append(array)
inferred_features[col] = generate_from_arrow_type(col_values.type)
else:
col_try_type = try_features[col] if try_features is not None and col in try_features else None
col_try_type = (
try_features[col]
if try_features is not None and col in try_features and try_original_type
else None
)
typed_sequence = OptimizedTypedSequence(col_values, type=col_type, try_type=col_try_type, col=col)
arrays.append(pa.array(typed_sequence))
inferred_features[col] = typed_sequence.get_inferred_type()
Expand Down
35 changes: 34 additions & 1 deletion tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,13 @@ def inject_fixtures(self, caplog, set_sqlalchemy_silence_uber_warning):
self._caplog = caplog

def _create_dummy_dataset(
self, in_memory: bool, tmp_dir: str, multiple_columns=False, array_features=False, nested_features=False
self,
in_memory: bool,
tmp_dir: str,
multiple_columns=False,
array_features=False,
nested_features=False,
int_to_float=False,
) -> Dataset:
assert int(multiple_columns) + int(array_features) + int(nested_features) < 2
if multiple_columns:
Expand All @@ -151,6 +157,12 @@ def _create_dummy_dataset(
data = {"nested": [{"a": i, "x": i * 10, "c": i * 100} for i in range(1, 11)]}
features = Features({"nested": {"a": Value("int64"), "x": Value("int64"), "c": Value("int64")}})
dset = Dataset.from_dict(data, features=features)
elif int_to_float:
data = {
"text": ["text1", "text2", "text3", "text4"],
"labels": [[1, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 1, 0]],
}
dset = Dataset.from_dict(data)
else:
dset = Dataset.from_dict({"filename": ["my_name-train" + "_" + str(x) for x in np.arange(30).tolist()]})
if not in_memory:
Expand Down Expand Up @@ -1123,6 +1135,27 @@ def func(x, i):
self.assertListEqual(sorted(dset_test[0].keys()), ["col_1", "col_1_plus_one"])
self.assertListEqual(sorted(dset_test.column_names), ["col_1", "col_1_plus_one", "col_2", "col_3"])
assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
# casting int labels to float labels
with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir, int_to_float=True) as dset:

def _preprocess(examples):
# Batched map function: convert every value of each "labels" list to float and
# return only the "labels" column.
result = {"labels": [list(map(float, labels)) for labels in examples["labels"]]}
return result

with dset.map(
_preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=True
) as dset_test:
for labels in dset_test["labels"]:
for label in labels:
self.assertIsInstance(label, int)

with dset.map(
_preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=False
) as dset_test:
for labels in dset_test["labels"]:
for label in labels:
self.assertIsInstance(label, float)

def test_map_multiprocessing(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir: # standard
Expand Down
Loading