diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index dbdff64953b..250442a0ed8 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -660,7 +660,11 @@ def __init__(self, source: Union["Dataset", "Column"], column_name: str): def __iter__(self) -> Iterator[Any]: if isinstance(self.source, Dataset): - source = self.source._fast_select_column(self.column_name) + if self.source._format_type == "custom": + # the formatting transform may require all columns + source = self.source + else: + source = self.source._fast_select_column(self.column_name) else: source = self.source for example in source: @@ -670,7 +674,12 @@ def __getitem__(self, key: Union[int, str, list[int]]) -> Any: if isinstance(key, str): return Column(self, key) elif isinstance(self.source, Dataset): - return self.source._fast_select_column(self.column_name)[key][self.column_name] + if self.source._format_type == "custom": + # the formatting transform may require all columns + source = self.source + else: + source = self.source._fast_select_column(self.column_name) + return source[key][self.column_name] elif isinstance(key, int): return self.source[key][self.column_name] else: