Skip to content
56 changes: 30 additions & 26 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2155,12 +2155,12 @@ def remove_columns(self, column_names: Union[str, List[str]], new_fingerprint: O
if isinstance(column_names, str):
column_names = [column_names]

for column_name in column_names:
if column_name not in dataset._data.column_names:
raise ValueError(
f"Column name {column_name} not in the dataset. "
f"Current columns in the dataset: {dataset._data.column_names}"
)
missing_columns = set(column_names) - set(self._data.column_names)
if missing_columns:
raise ValueError(
f"Column name {list(missing_columns)} not in the dataset. "
f"Current columns in the dataset: {dataset._data.column_names}"
)

for column_name in column_names:
del dataset._info.features[column_name]
Expand Down Expand Up @@ -2339,13 +2339,13 @@ def select_columns(self, column_names: Union[str, List[str]], new_fingerprint: O
if isinstance(column_names, str):
column_names = [column_names]

for column_name in column_names:
if column_name not in self._data.column_names:
raise ValueError(
f"Column name {column_name} not in the "
"dataset. Current columns in the dataset: "
f"{self._data.column_names}."
)
missing_columns = set(column_names) - set(self._data.column_names)
if missing_columns:
raise ValueError(
f"Column name {list(missing_columns)} not in the "
"dataset. Current columns in the dataset: "
f"{self._data.column_names}."
)

dataset = copy.deepcopy(self)
dataset._data = dataset._data.select(column_names)
Expand Down Expand Up @@ -2534,10 +2534,12 @@ def set_format(
columns = [columns]
if isinstance(columns, tuple):
columns = list(columns)
if columns is not None and any(col not in self._data.column_names for col in columns):
raise ValueError(
f"Columns {list(filter(lambda col: col not in self._data.column_names, columns))} not in the dataset. Current columns in the dataset: {self._data.column_names}"
)
if columns is not None:
missing_columns = set(columns) - set(self._data.column_names)
if missing_columns:
raise ValueError(
f"Columns {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
)
if columns is not None:
columns = columns.copy() # Ensures modifications made to the list after this call don't cause bugs

Expand Down Expand Up @@ -3008,19 +3010,21 @@ def map(
input_columns = [input_columns]

if input_columns is not None:
for input_column in input_columns:
if input_column not in self._data.column_names:
raise ValueError(
f"Input column {input_column} not in the dataset. Current columns in the dataset: {self._data.column_names}"
)
missing_columns = set(input_columns) - set(self._data.column_names)
if missing_columns:
raise ValueError(
f"Input column {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
)

if isinstance(remove_columns, str):
remove_columns = [remove_columns]

if remove_columns is not None and any(col not in self._data.column_names for col in remove_columns):
raise ValueError(
f"Column to remove {list(filter(lambda col: col not in self._data.column_names, remove_columns))} not in the dataset. Current columns in the dataset: {self._data.column_names}"
)
if remove_columns is not None:
missing_columns = set(remove_columns) - set(self._data.column_names)
if missing_columns:
raise ValueError(
f"Column to remove {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
)

load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()

Expand Down
35 changes: 23 additions & 12 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,18 @@ def write_examples_on_file(self):
return

# order the columns properly
cols = (
[col for col in self.schema.names if col in self.current_examples[0][0]]
+ [col for col in self.current_examples[0][0].keys() if col not in self.schema.names]
if self.schema
else self.current_examples[0][0].keys()
)
if self.schema:
schema_cols = set(self.schema.names)
common_cols, extra_cols = [], []
for col in self.current_examples[0][0]:
if col in schema_cols:
common_cols.append(col)
else:
extra_cols.append(col)
cols = common_cols + extra_cols
else:
cols = list(self.current_examples[0][0])

batch_examples = {}
for col in cols:
# We use row[0][col] since current_examples contains (example, key) tuples.
Expand Down Expand Up @@ -538,12 +544,17 @@ def write_batch(
try_features = self._features if self.pa_writer is None and self.update_features else None
arrays = []
inferred_features = Features()
cols = (
[col for col in self.schema.names if col in batch_examples]
+ [col for col in batch_examples.keys() if col not in self.schema.names]
if self.schema
else batch_examples.keys()
)
if self.schema:
schema_cols = set(self.schema.names)
common_cols, extra_cols = [], []
for col in batch_examples:
if col in schema_cols:
common_cols.append(col)
else:
extra_cols.append(col)
cols = common_cols + extra_cols
else:
cols = list(batch_examples)
for col in cols:
col_values = batch_examples[col]
col_type = features[col] if features else None
Expand Down
14 changes: 7 additions & 7 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,13 +2067,13 @@ def select_columns(self, column_names: Union[str, List[str]]) -> "IterableDatase
if self._info:
info = copy.deepcopy(self._info)
if self._info.features is not None:
for column_name in column_names:
if column_name not in self._info.features:
raise ValueError(
f"Column name {column_name} not in the "
"dataset. Columns in the dataset: "
f"{list(self._info.features.keys())}."
)
missing_columns = set(column_names) - set(self._info.features.keys())
if missing_columns:
raise ValueError(
f"Column name {list(missing_columns)} not in the "
"dataset. Columns in the dataset: "
f"{list(self._info.features.keys())}."
)
info.features = Features({c: info.features[c] for c in column_names})
# check that it's still valid, especially with regard to task templates
try:
Expand Down