Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3296,7 +3296,7 @@ class NumExamplesMismatchError(Exception):

def validate_function_output(processed_inputs, indices):
"""Validate output of the map function."""
if processed_inputs is not None and not isinstance(processed_inputs, (Mapping, pa.Table)):
if processed_inputs is not None and not isinstance(processed_inputs, (Mapping, pa.Table, pd.DataFrame)):
raise TypeError(
f"Provided `function` which is applied to all elements of table returns a variable of type {type(processed_inputs)}. Make sure provided `function` returns a variable of type `dict` (or a pyarrow table) to update the dataset or `None` if you are only interested in side effects."
)
Expand Down Expand Up @@ -3351,7 +3351,7 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example
returned_lazy_dict = False
if update_data is None:
# Check if the function returns updated examples
update_data = isinstance(processed_inputs, (Mapping, pa.Table))
update_data = isinstance(processed_inputs, (Mapping, pa.Table, pd.DataFrame))
validate_function_output(processed_inputs, indices)
if not update_data:
return None # Nothing to update, let's move on
Expand Down Expand Up @@ -3445,6 +3445,8 @@ def init_buffer_and_writer():
stack.enter_context(writer)
if isinstance(example, pa.Table):
writer.write_row(example)
elif isinstance(example, pd.DataFrame):
writer.write_row(pa.Table.from_pandas(example))
else:
writer.write(example)
num_examples_progress_update += 1
Expand Down Expand Up @@ -3476,6 +3478,8 @@ def init_buffer_and_writer():
stack.enter_context(writer)
if isinstance(batch, pa.Table):
writer.write_table(batch)
elif isinstance(batch, pd.DataFrame):
writer.write_table(pa.Table.from_pandas(batch))
else:
writer.write_batch(batch)
num_examples_progress_update += num_examples_in_batch
Expand Down
2 changes: 2 additions & 0 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ def write_row(self, row: pa.Table, writer_batch_size: Optional[int] = None):
Args:
row: the row to add.
"""
if len(row) != 1:
raise ValueError(f"Only single-row pyarrow tables are allowed but got table with {len(row)} rows.")
self.current_rows.append(row)
if writer_batch_size is None:
writer_batch_size = self.writer_batch_size
Expand Down
80 changes: 79 additions & 1 deletion tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,7 @@ def test_map_multiprocessing(self, in_memory):
self.assertNotEqual(dset_test._fingerprint, fingerprint)
assert_arrow_metadata_are_synced_with_dataset_features(dset_test)

def test_new_features(self, in_memory):
def test_map_new_features(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
features = Features({"filename": Value("string"), "label": ClassLabel(names=["positive", "negative"])})
Expand Down Expand Up @@ -1388,6 +1388,84 @@ def test_map_caching(self, in_memory):
finally:
datasets.enable_caching()

def test_map_return_pa_table(self, in_memory):
    # `Dataset.map` should accept functions that return a `pa.Table`, both per example and per batch.
    def one_row_table(example):
        return pa.table({"id": [0], "text": ["a"]})

    def batch_sized_table(batch):
        # One output row per input row: the row count is read off the first input column.
        first_column = next(iter(batch))
        num_rows = len(batch[first_column])
        return pa.table({"id": [0] * num_rows, "text": ["a"] * num_rows})

    def two_row_table(example):
        return pa.table({"id": [0, 1], "text": ["a", "b"]})

    def check_mapped(mapped):
        # Every mapped row is the constant {"id": 0, "text": "a"}; the assertions expect 30 rows.
        self.assertEqual(len(mapped), 30)
        self.assertDictEqual(
            mapped.features,
            Features({"id": Value("int64"), "text": Value("string")}),
        )
        first_row = mapped[0]
        self.assertEqual(first_row["id"], 0)
        self.assertEqual(first_row["text"], "a")

    # Non-batched: each call returns a single-row table.
    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with dset.map(one_row_table) as mapped:
            check_mapped(mapped)

    # Batched: each call returns a table with as many rows as the input batch.
    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with dset.map(batch_sized_table, batched=True) as mapped:
            check_mapped(mapped)

    # Non-batched mode must reject a table holding more than one row.
    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with self.assertRaises(ValueError):
            dset.map(two_row_table)

def test_map_return_pd_dataframe(self, in_memory):
    # `Dataset.map` should accept functions that return a `pd.DataFrame`, both per example and per batch.
    def one_row_frame(example):
        return pd.DataFrame({"id": [0], "text": ["a"]})

    def batch_sized_frame(batch):
        # One output row per input row: the row count is read off the first input column.
        first_column = next(iter(batch))
        num_rows = len(batch[first_column])
        return pd.DataFrame({"id": [0] * num_rows, "text": ["a"] * num_rows})

    def two_row_frame(example):
        return pd.DataFrame({"id": [0, 1], "text": ["a", "b"]})

    def check_mapped(mapped):
        # Every mapped row is the constant {"id": 0, "text": "a"}; the assertions expect 30 rows.
        self.assertEqual(len(mapped), 30)
        self.assertDictEqual(
            mapped.features,
            Features({"id": Value("int64"), "text": Value("string")}),
        )
        first_row = mapped[0]
        self.assertEqual(first_row["id"], 0)
        self.assertEqual(first_row["text"], "a")

    # Non-batched: each call returns a single-row DataFrame.
    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with dset.map(one_row_frame) as mapped:
            check_mapped(mapped)

    # Batched: each call returns a DataFrame with as many rows as the input batch.
    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with dset.map(batch_sized_frame, batched=True) as mapped:
            check_mapped(mapped)

    # Non-batched mode must reject a DataFrame holding more than one row.
    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with self.assertRaises(ValueError):
            dset.map(two_row_frame)

@require_torch
def test_map_torch(self, in_memory):
import torch
Expand Down
4 changes: 2 additions & 2 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ def data_dir_with_arrow(tmp_path):
data_dir.mkdir()
output_train = os.path.join(data_dir, "train.arrow")
with ArrowWriter(path=output_train) as writer:
writer.write_row(pa.Table.from_pydict({"col_1": ["foo"] * 10}))
writer.write_table(pa.Table.from_pydict({"col_1": ["foo"] * 10}))
num_examples, num_bytes = writer.finalize()
assert num_examples == 10
assert num_bytes > 0
output_test = os.path.join(data_dir, "test.arrow")
with ArrowWriter(path=output_test) as writer:
writer.write_row(pa.Table.from_pydict({"col_1": ["bar"] * 10}))
writer.write_table(pa.Table.from_pydict({"col_1": ["bar"] * 10}))
num_examples, num_bytes = writer.finalize()
assert num_examples == 10
assert num_bytes > 0
Expand Down