Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,57 @@ def rename(columns):
dataset._fingerprint = new_fingerprint
return dataset

@fingerprint_transform(inplace=False)
def rename_columns(self, column_mapping: Dict[str, str], new_fingerprint) -> "Dataset":
"""
Rename several columns in the dataset, and move the features associated to the original columns under
the new column names.

Args:
column_mapping (:obj:`Dict[str, str]`): A mapping of columns to rename to their new names

Returns:
:class:`Dataset`: A copy of the dataset with renamed columns
"""
dataset = copy.deepcopy(self)

extra_columns = set(column_mapping.keys()) - set(dataset.column_names)
if extra_columns:
raise ValueError(
f"Original column names {extra_columns} not in the dataset. "
f"Current columns in the dataset: {dataset._data.column_names}"
)

number_of_duplicates_in_new_columns = len(column_mapping.values()) - len(set(column_mapping.values()))
if number_of_duplicates_in_new_columns != 0:
raise ValueError(
"New column names must all be different, but this column mapping "
f"has {number_of_duplicates_in_new_columns} duplicates"
)

empty_new_columns = [new_col for new_col in column_mapping.values() if not new_col]
if empty_new_columns:
raise ValueError(f"New column names {empty_new_columns} are empty.")

def rename(columns):
return [column_mapping[col] if col in column_mapping else col for col in columns]

new_column_names = rename(self._data.column_names)
if self._format_columns is not None:
dataset._format_columns = rename(self._format_columns)

dataset._info.features = Features(
{
column_mapping[col] if col in column_mapping else col: feature
for col, feature in (self._info.features or {}).items()
}
)

dataset._data = dataset._data.rename_columns(new_column_names)
dataset._data = update_metadata_with_features(dataset._data, self.features)
dataset._fingerprint = new_fingerprint
return dataset

def __len__(self):
"""Number of rows in the dataset."""
return self.num_rows
Expand Down
28 changes: 28 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,34 @@ def test_rename_column(self, in_memory):
self.assertNotEqual(new_dset._fingerprint, fingerprint)
assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

def test_rename_columns(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
fingerprint = dset._fingerprint
with dset.rename_columns({"col_1": "new_name"}) as new_dset:
self.assertEqual(new_dset.num_columns, 3)
self.assertListEqual(list(new_dset.column_names), ["new_name", "col_2", "col_3"])
self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
self.assertNotEqual(new_dset._fingerprint, fingerprint)

with dset.rename_columns({"col_1": "new_name", "col_2": "new_name2"}) as new_dset:
self.assertEqual(new_dset.num_columns, 3)
self.assertListEqual(list(new_dset.column_names), ["new_name", "new_name2", "col_3"])
self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
self.assertNotEqual(new_dset._fingerprint, fingerprint)

# Original column not in dataset
with self.assertRaises(ValueError):
dset.rename_columns({"not_there": "new_name"})

# Empty new name
with self.assertRaises(ValueError):
dset.rename_columns({"col_1": ""})

# Duplicates
with self.assertRaises(ValueError):
dset.rename_columns({"col_1": "new_name", "col_2": "new_name"})

def test_concatenate(self, in_memory):
data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
info1 = DatasetInfo(description="Dataset1")
Expand Down