Skip to content

Commit 597a97a

Browse files
committed
Address comments
1 parent b02f546 commit 597a97a

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

src/datasets/packaged_modules/pandas/pandas.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
import pyarrow as pa
66

77
import datasets
8-
from datasets.features.features import require_storage_cast
98
from datasets.table import table_cast
109

1110

@@ -41,13 +40,9 @@ def _split_generators(self, dl_manager):
4140

4241
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
4342
if self.config.features is not None:
44-
schema = self.config.features.arrow_schema
45-
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
46-
# cheaper cast
47-
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
48-
else:
49-
# more expensive cast; allows str <-> int/float or str to Audio for example
50-
pa_table = table_cast(pa_table, schema)
43+
# more expensive cast to support nested features with keys in a different order
44+
# allows str <-> int/float or str to Audio for example
45+
pa_table = table_cast(pa_table, self.config.features.arrow_schema)
5146
return pa_table
5247

5348
def _generate_tables(self, files):

src/datasets/packaged_modules/parquet/parquet.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
import pyarrow.parquet as pq
66

77
import datasets
8-
from datasets.features.features import require_storage_cast
98
from datasets.table import table_cast
109

1110

@@ -46,13 +45,9 @@ def _split_generators(self, dl_manager):
4645

4746
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
4847
if self.config.features is not None:
49-
schema = self.config.features.arrow_schema
50-
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
51-
# cheaper cast
52-
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
53-
else:
54-
# more expensive cast; allows str <-> int/float or str to Audio for example
55-
pa_table = table_cast(pa_table, schema)
48+
# more expensive cast to support nested features with keys in a different order
49+
# allows str <-> int/float or str to Audio for example
50+
pa_table = table_cast(pa_table, self.config.features.arrow_schema)
5651
return pa_table
5752

5853
def _generate_tables(self, files):

0 commit comments

Comments
 (0)