Skip to content

Commit 341b555

Browse files
authored
Preserve non-input_colums in Dataset.map if input_columns are specified (#4971)
* Preserve non-`input_colums` in `Dataset.map` * Remove unnecessary assert * Address review comments
1 parent e195bc1 commit 341b555

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/datasets/arrow_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2680,7 +2680,7 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
26802680
validate_function_output(processed_inputs, indices)
26812681
if not update_data:
26822682
return None # Nothing to update, let's move on
2683-
if self._format_type is not None:
2683+
if self._format_type is not None or input_columns:
26842684
inputs = self._getitem(
26852685
key=(indices if isinstance(indices, int) else slice(indices[0], indices[-1] + 1)),
26862686
format_type=None,

tests/test_arrow_dataset.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,23 @@ def func(example):
13411341
)
13421342
self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])
13431343

1344+
def test_map_input_columns(self, in_memory):
1345+
with tempfile.TemporaryDirectory() as tmp_dir:
1346+
with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
1347+
with dset.map(lambda col_1: {"label": col_1 % 2}, input_columns="col_1") as mapped_dset:
1348+
self.assertEqual(mapped_dset[0].keys(), {"col_1", "col_2", "col_3", "label"})
1349+
self.assertEqual(
1350+
mapped_dset.features,
1351+
Features(
1352+
{
1353+
"col_1": Value("int64"),
1354+
"col_2": Value("string"),
1355+
"col_3": Value("bool"),
1356+
"label": Value("int64"),
1357+
}
1358+
),
1359+
)
1360+
13441361
def test_map_remove_columns(self, in_memory):
13451362
with tempfile.TemporaryDirectory() as tmp_dir:
13461363
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:

0 commit comments

Comments
 (0)