diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 073e47aefca..341092b9642 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -3296,7 +3296,7 @@ class NumExamplesMismatchError(Exception):
 
         def validate_function_output(processed_inputs, indices):
             """Validate output of the map function."""
-            if processed_inputs is not None and not isinstance(processed_inputs, (Mapping, pa.Table)):
+            if processed_inputs is not None and not isinstance(processed_inputs, (Mapping, pa.Table, pd.DataFrame)):
                 raise TypeError(
                     f"Provided `function` which is applied to all elements of table returns a variable of type {type(processed_inputs)}. Make sure provided `function` returns a variable of type `dict` (or a pyarrow table) to update the dataset or `None` if you are only interested in side effects."
                 )
@@ -3351,7 +3351,7 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples
                 returned_lazy_dict = False
             if update_data is None:
                 # Check if the function returns updated examples
-                update_data = isinstance(processed_inputs, (Mapping, pa.Table))
+                update_data = isinstance(processed_inputs, (Mapping, pa.Table, pd.DataFrame))
                 validate_function_output(processed_inputs, indices)
             if not update_data:
                 return None  # Nothing to update, let's move on
@@ -3445,6 +3445,8 @@ def init_buffer_and_writer():
                                 stack.enter_context(writer)
                             if isinstance(example, pa.Table):
                                 writer.write_row(example)
+                            elif isinstance(example, pd.DataFrame):
+                                writer.write_row(pa.Table.from_pandas(example))
                             else:
                                 writer.write(example)
                             num_examples_progress_update += 1
@@ -3476,6 +3478,8 @@ def init_buffer_and_writer():
                                 stack.enter_context(writer)
                             if isinstance(batch, pa.Table):
                                 writer.write_table(batch)
+                            elif isinstance(batch, pd.DataFrame):
+                                writer.write_table(pa.Table.from_pandas(batch))
                             else:
                                 writer.write_batch(batch)
                             num_examples_progress_update += num_examples_in_batch
diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index 10d6d59153c..87aeb9124c2 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -510,6 +510,8 @@ def write_row(self, row: pa.Table, writer_batch_size: Optional[int] = None):
         Args:
             row: the row to add.
         """
+        if len(row) != 1:
+            raise ValueError(f"Only single-row pyarrow tables are allowed but got table with {len(row)} rows.")
         self.current_rows.append(row)
         if writer_batch_size is None:
             writer_batch_size = self.writer_batch_size
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 3846d43498d..d2cf5273362 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -1181,7 +1181,7 @@ def test_map_multiprocessing(self, in_memory):
             self.assertNotEqual(dset_test._fingerprint, fingerprint)
             assert_arrow_metadata_are_synced_with_dataset_features(dset_test)
 
-    def test_new_features(self, in_memory):
+    def test_map_new_features(self, in_memory):
         with tempfile.TemporaryDirectory() as tmp_dir:
             with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
                 features = Features({"filename": Value("string"), "label": ClassLabel(names=["positive", "negative"])})
@@ -1388,6 +1388,84 @@ def test_map_caching(self, in_memory):
             finally:
                 datasets.enable_caching()
 
+    def test_map_return_pa_table(self, in_memory):
+        def func_return_single_row_pa_table(x):
+            return pa.table({"id": [0], "text": ["a"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pa_table) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Batched
+        def func_return_single_row_pa_table_batched(x):
+            batch_size = len(x[next(iter(x))])
+            return pa.table({"id": [0] * batch_size, "text": ["a"] * batch_size})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pa_table_batched, batched=True) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Error when returning a table with more than one row in the non-batched mode
+        def func_return_multi_row_pa_table(x):
+            return pa.table({"id": [0, 1], "text": ["a", "b"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                self.assertRaises(ValueError, dset.map, func_return_multi_row_pa_table)
+
+    def test_map_return_pd_dataframe(self, in_memory):
+        def func_return_single_row_pd_dataframe(x):
+            return pd.DataFrame({"id": [0], "text": ["a"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pd_dataframe) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Batched
+        def func_return_single_row_pd_dataframe_batched(x):
+            batch_size = len(x[next(iter(x))])
+            return pd.DataFrame({"id": [0] * batch_size, "text": ["a"] * batch_size})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                with dset.map(func_return_single_row_pd_dataframe_batched, batched=True) as dset_test:
+                    self.assertEqual(len(dset_test), 30)
+                    self.assertDictEqual(
+                        dset_test.features,
+                        Features({"id": Value("int64"), "text": Value("string")}),
+                    )
+                    self.assertEqual(dset_test[0]["id"], 0)
+                    self.assertEqual(dset_test[0]["text"], "a")
+
+        # Error when returning a table with more than one row in the non-batched mode
+        def func_return_multi_row_pd_dataframe(x):
+            return pd.DataFrame({"id": [0, 1], "text": ["a", "b"]})
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
+                self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)
+
     @require_torch
     def test_map_torch(self, in_memory):
         import torch
diff --git a/tests/test_load.py b/tests/test_load.py
index a4202e523a9..8e5ff330c51 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -115,13 +115,13 @@ def data_dir_with_arrow(tmp_path):
     data_dir.mkdir()
     output_train = os.path.join(data_dir, "train.arrow")
     with ArrowWriter(path=output_train) as writer:
-        writer.write_row(pa.Table.from_pydict({"col_1": ["foo"] * 10}))
+        writer.write_table(pa.Table.from_pydict({"col_1": ["foo"] * 10}))
         num_examples, num_bytes = writer.finalize()
         assert num_examples == 10
         assert num_bytes > 0
     output_test = os.path.join(data_dir, "test.arrow")
     with ArrowWriter(path=output_test) as writer:
-        writer.write_row(pa.Table.from_pydict({"col_1": ["bar"] * 10}))
+        writer.write_table(pa.Table.from_pydict({"col_1": ["bar"] * 10}))
         num_examples, num_bytes = writer.finalize()
         assert num_examples == 10
         assert num_bytes > 0
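The `test_load.py` fixture previously passed 10-row tables to `write_row`; under the new single-row guard that would raise, so multi-row tables now go through `write_table` instead. A short sketch of the corrected pattern, assuming the patched `ArrowWriter` (the path `train.arrow` is a stand-in, not taken from the patch):

```python
import pyarrow as pa
from datasets.arrow_writer import ArrowWriter

with ArrowWriter(path="train.arrow") as writer:
    # Ten rows at once must go through write_table; write_row would now
    # raise ValueError because the table has more than one row.
    writer.write_table(pa.Table.from_pydict({"col_1": ["foo"] * 10}))
    num_examples, num_bytes = writer.finalize()
    assert num_examples == 10
    assert num_bytes > 0
```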