diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index f7062835d41..5c20c907979 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -2155,12 +2155,12 @@ def remove_columns(self, column_names: Union[str, List[str]], new_fingerprint: O
         if isinstance(column_names, str):
             column_names = [column_names]
 
-        for column_name in column_names:
-            if column_name not in dataset._data.column_names:
-                raise ValueError(
-                    f"Column name {column_name} not in the dataset. "
-                    f"Current columns in the dataset: {dataset._data.column_names}"
-                )
+        missing_columns = set(column_names) - set(self._data.column_names)
+        if missing_columns:
+            raise ValueError(
+                f"Column name {list(missing_columns)} not in the dataset. "
+                f"Current columns in the dataset: {dataset._data.column_names}"
+            )
 
         for column_name in column_names:
             del dataset._info.features[column_name]
@@ -2339,13 +2339,13 @@ def select_columns(self, column_names: Union[str, List[str]], new_fingerprint: O
         if isinstance(column_names, str):
             column_names = [column_names]
 
-        for column_name in column_names:
-            if column_name not in self._data.column_names:
-                raise ValueError(
-                    f"Column name {column_name} not in the "
-                    "dataset. Current columns in the dataset: "
-                    f"{self._data.column_names}."
-                )
+        missing_columns = set(column_names) - set(self._data.column_names)
+        if missing_columns:
+            raise ValueError(
+                f"Column name {list(missing_columns)} not in the "
+                "dataset. Current columns in the dataset: "
+                f"{self._data.column_names}."
+            )
 
         dataset = copy.deepcopy(self)
         dataset._data = dataset._data.select(column_names)
@@ -2534,10 +2534,12 @@ def set_format(
             columns = [columns]
         if isinstance(columns, tuple):
             columns = list(columns)
-        if columns is not None and any(col not in self._data.column_names for col in columns):
-            raise ValueError(
-                f"Columns {list(filter(lambda col: col not in self._data.column_names, columns))} not in the dataset. Current columns in the dataset: {self._data.column_names}"
-            )
+        if columns is not None:
+            missing_columns = set(columns) - set(self._data.column_names)
+            if missing_columns:
+                raise ValueError(
+                    f"Columns {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
+                )
         if columns is not None:
             columns = columns.copy()  # Ensures modifications made to the list after this call don't cause bugs
 
@@ -3008,19 +3010,21 @@ def map(
             input_columns = [input_columns]
 
         if input_columns is not None:
-            for input_column in input_columns:
-                if input_column not in self._data.column_names:
-                    raise ValueError(
-                        f"Input column {input_column} not in the dataset. Current columns in the dataset: {self._data.column_names}"
-                    )
+            missing_columns = set(input_columns) - set(self._data.column_names)
+            if missing_columns:
+                raise ValueError(
+                    f"Input column {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
+                )
 
         if isinstance(remove_columns, str):
             remove_columns = [remove_columns]
 
-        if remove_columns is not None and any(col not in self._data.column_names for col in remove_columns):
-            raise ValueError(
-                f"Column to remove {list(filter(lambda col: col not in self._data.column_names, remove_columns))} not in the dataset. Current columns in the dataset: {self._data.column_names}"
-            )
+        if remove_columns is not None:
+            missing_columns = set(remove_columns) - set(self._data.column_names)
+            if missing_columns:
+                raise ValueError(
+                    f"Column to remove {list(missing_columns)} not in the dataset. Current columns in the dataset: {self._data.column_names}"
+                )
 
         load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()
 
diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index 5d6d8141f6d..abbac174e83 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -428,12 +428,18 @@ def write_examples_on_file(self):
             return
 
         # order the columns properly
-        cols = (
-            [col for col in self.schema.names if col in self.current_examples[0][0]]
-            + [col for col in self.current_examples[0][0].keys() if col not in self.schema.names]
-            if self.schema
-            else self.current_examples[0][0].keys()
-        )
+        if self.schema:
+            schema_cols = set(self.schema.names)
+            common_cols, extra_cols = [], []
+            for col in self.current_examples[0][0]:
+                if col in schema_cols:
+                    common_cols.append(col)
+                else:
+                    extra_cols.append(col)
+            cols = common_cols + extra_cols
+        else:
+            cols = list(self.current_examples[0][0])
+
         batch_examples = {}
         for col in cols:
             # We use row[0][col] since current_examples contains (example, key) tuples.
@@ -538,12 +544,17 @@ def write_batch(
         try_features = self._features if self.pa_writer is None and self.update_features else None
         arrays = []
        inferred_features = Features()
-        cols = (
-            [col for col in self.schema.names if col in batch_examples]
-            + [col for col in batch_examples.keys() if col not in self.schema.names]
-            if self.schema
-            else batch_examples.keys()
-        )
+        if self.schema:
+            schema_cols = set(self.schema.names)
+            common_cols, extra_cols = [], []
+            for col in batch_examples:
+                if col in schema_cols:
+                    common_cols.append(col)
+                else:
+                    extra_cols.append(col)
+            cols = common_cols + extra_cols
+        else:
+            cols = list(batch_examples)
         for col in cols:
             col_values = batch_examples[col]
             col_type = features[col] if features else None
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 46b968fe333..f508a0b5271 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -2067,13 +2067,13 @@ def select_columns(self, column_names: Union[str, List[str]]) -> "IterableDatase
         if self._info:
             info = copy.deepcopy(self._info)
             if self._info.features is not None:
-                for column_name in column_names:
-                    if column_name not in self._info.features:
-                        raise ValueError(
-                            f"Column name {column_name} not in the "
-                            "dataset. Columns in the dataset: "
-                            f"{list(self._info.features.keys())}."
-                        )
+                missing_columns = set(column_names) - set(self._info.features.keys())
+                if missing_columns:
+                    raise ValueError(
+                        f"Column name {list(missing_columns)} not in the "
+                        "dataset. Columns in the dataset: "
+                        f"{list(self._info.features.keys())}."
+                    )
                 info.features = Features({c: info.features[c] for c in column_names})
             # check that it's still valid, especially with regard to task templates
             try: