Commit 0d2b885

Support returning dataframe in map transform (#5995)
* Support returning dataframe in map transform
* Fix test
1 parent 6f3f38d commit 0d2b885
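
For context: after this change, the function passed to `Dataset.map` may return a `pandas.DataFrame` in addition to a `dict` or a `pyarrow.Table`. A minimal usage sketch (the toy dataset and column values below are hypothetical, not taken from this commit):

import pandas as pd
from datasets import Dataset

# Hypothetical toy dataset, used only to illustrate the new return type.
dset = Dataset.from_dict({"text": ["foo", "bar"]})

# Non-batched mode: the function must return a single-row DataFrame.
upper = dset.map(lambda x: pd.DataFrame({"text": [x["text"].upper()]}))

# Batched mode: the DataFrame holds one row per input example.
upper = dset.map(
    lambda batch: pd.DataFrame({"text": [t.upper() for t in batch["text"]]}),
    batched=True,
)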

File tree (4 files changed: +89 -5 lines)

src/datasets/arrow_dataset.py
src/datasets/arrow_writer.py
tests/test_arrow_dataset.py
tests/test_load.py

src/datasets/arrow_dataset.py

Lines changed: 6 additions & 2 deletions
@@ -3296,7 +3296,7 @@ class NumExamplesMismatchError(Exception):
 
 def validate_function_output(processed_inputs, indices):
     """Validate output of the map function."""
-    if processed_inputs is not None and not isinstance(processed_inputs, (Mapping, pa.Table)):
+    if processed_inputs is not None and not isinstance(processed_inputs, (Mapping, pa.Table, pd.DataFrame)):
         raise TypeError(
             f"Provided `function` which is applied to all elements of table returns a variable of type {type(processed_inputs)}. Make sure provided `function` returns a variable of type `dict` (or a pyarrow table) to update the dataset or `None` if you are only interested in side effects."
         )

@@ -3351,7 +3351,7 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example
         returned_lazy_dict = False
         if update_data is None:
             # Check if the function returns updated examples
-            update_data = isinstance(processed_inputs, (Mapping, pa.Table))
+            update_data = isinstance(processed_inputs, (Mapping, pa.Table, pd.DataFrame))
         validate_function_output(processed_inputs, indices)
         if not update_data:
             return None  # Nothing to update, let's move on

@@ -3445,6 +3445,8 @@ def init_buffer_and_writer():
                                 stack.enter_context(writer)
                             if isinstance(example, pa.Table):
                                 writer.write_row(example)
+                            elif isinstance(example, pd.DataFrame):
+                                writer.write_row(pa.Table.from_pandas(example))
                             else:
                                 writer.write(example)
                         num_examples_progress_update += 1

@@ -3476,6 +3478,8 @@ def init_buffer_and_writer():
                                 stack.enter_context(writer)
                             if isinstance(batch, pa.Table):
                                 writer.write_table(batch)
+                            elif isinstance(batch, pd.DataFrame):
+                                writer.write_table(pa.Table.from_pandas(batch))
                             else:
                                 writer.write_batch(batch)
                         num_examples_progress_update += num_examples_in_batch
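
Both new branches above convert the returned DataFrame with `pa.Table.from_pandas` before handing it to the writer. A standalone sketch of that conversion (hypothetical column names):

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"id": [0], "text": ["a"]})
table = pa.Table.from_pandas(df)  # schema is inferred: id: int64, text: string
assert len(table) == 1  # a single-row table, as write_row expects in non-batched mode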

src/datasets/arrow_writer.py

Lines changed: 2 additions & 0 deletions
@@ -510,6 +510,8 @@ def write_row(self, row: pa.Table, writer_batch_size: Optional[int] = None):
         Args:
             row: the row to add.
         """
+        if len(row) != 1:
+            raise ValueError(f"Only single-row pyarrow tables are allowed but got table with {len(row)} rows.")
         self.current_rows.append(row)
         if writer_batch_size is None:
             writer_batch_size = self.writer_batch_size
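
With this guard, `write_row` only accepts single-row tables; multi-row tables belong in `write_table` instead (which is exactly the adjustment made in tests/test_load.py below). A minimal sketch of the expected behavior, assuming a hypothetical output path:

import pyarrow as pa
from datasets.arrow_writer import ArrowWriter

with ArrowWriter(path="out.arrow") as writer:  # hypothetical path
    writer.write_row(pa.Table.from_pydict({"col_1": ["foo"]}))  # OK: exactly one row
    # writer.write_row(pa.Table.from_pydict({"col_1": ["a", "b"]}))  # would raise ValueError (2 rows)
    num_examples, num_bytes = writer.finalize()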

tests/test_arrow_dataset.py

Lines changed: 79 additions & 1 deletion
@@ -1190,7 +1190,7 @@ def test_map_multiprocessing(self, in_memory):
         self.assertNotEqual(dset_test._fingerprint, fingerprint)
         assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
 
-    def test_new_features(self, in_memory):
+    def test_map_new_features(self, in_memory):
         with tempfile.TemporaryDirectory() as tmp_dir:
             with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                 features = Features({"filename": Value("string"), "label": ClassLabel(names=["positive", "negative"])})

@@ -1397,6 +1397,84 @@ def test_map_caching(self, in_memory):
         finally:
             datasets.enable_caching()
 
+    def test_map_return_pa_table(self, in_memory):
+        def func_return_single_row_pa_table(x):
+            return pa.table({"id": [0], "text": ["a"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pa_table) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Batched
+        def func_return_single_row_pa_table_batched(x):
+            batch_size = len(x[next(iter(x))])
+            return pa.table({"id": [0] * batch_size, "text": ["a"] * batch_size})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pa_table_batched, batched=True) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Error when returning a table with more than one row in the non-batched mode
+        def func_return_multi_row_pa_table(x):
+            return pa.table({"id": [0, 1], "text": ["a", "b"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                self.assertRaises(ValueError, dset.map, func_return_multi_row_pa_table)
+
+    def test_map_return_pd_dataframe(self, in_memory):
+        def func_return_single_row_pd_dataframe(x):
+            return pd.DataFrame({"id": [0], "text": ["a"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pd_dataframe) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Batched
+        def func_return_single_row_pd_dataframe_batched(x):
+            batch_size = len(x[next(iter(x))])
+            return pd.DataFrame({"id": [0] * batch_size, "text": ["a"] * batch_size})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pd_dataframe_batched, batched=True) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Error when returning a table with more than one row in the non-batched mode
+        def func_return_multi_row_pd_dataframe(x):
+            return pd.DataFrame({"id": [0, 1], "text": ["a", "b"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)
+
     @require_torch
     def test_map_torch(self, in_memory):
         import torch

tests/test_load.py

Lines changed: 2 additions & 2 deletions
@@ -115,13 +115,13 @@ def data_dir_with_arrow(tmp_path):
     data_dir.mkdir()
     output_train = os.path.join(data_dir, "train.arrow")
     with ArrowWriter(path=output_train) as writer:
-        writer.write_row(pa.Table.from_pydict({"col_1": ["foo"] * 10}))
+        writer.write_table(pa.Table.from_pydict({"col_1": ["foo"] * 10}))
         num_examples, num_bytes = writer.finalize()
     assert num_examples == 10
     assert num_bytes > 0
     output_test = os.path.join(data_dir, "test.arrow")
     with ArrowWriter(path=output_test) as writer:
-        writer.write_row(pa.Table.from_pydict({"col_1": ["bar"] * 10}))
+        writer.write_table(pa.Table.from_pydict({"col_1": ["bar"] * 10}))
         num_examples, num_bytes = writer.finalize()
     assert num_examples == 10
     assert num_bytes > 0
