Skip to content

Commit b02f546

Browse files
committed
Address comments
1 parent 6f9791d commit b02f546

File tree

5 files changed

+10
-35
lines changed

5 files changed

+10
-35
lines changed

src/datasets/packaged_modules/csv/csv.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,6 @@ def __post_init__(self):
6868
if self.column_names is not None:
6969
self.names = self.column_names
7070

71-
@property
72-
def schema(self):
73-
return self.features.arrow_schema if self.features is not None else None
74-
7571
@property
7672
def read_csv_kwargs(self):
7773
read_csv_kwargs = dict(
@@ -154,7 +150,7 @@ def _split_generators(self, dl_manager):
154150

155151
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
156152
if self.config.features is not None:
157-
schema = self.config.schema
153+
schema = self.config.features.arrow_schema
158154
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
159155
# cheaper cast
160156
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
@@ -164,14 +160,14 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
164160
return pa_table
165161

166162
def _generate_tables(self, files):
167-
schema = self.config.schema
163+
schema = self.config.features.arrow_schema if self.config.features else None
168164
# dtype allows reading an int column as str
169165
dtype = (
170166
{
171167
name: dtype.to_pandas_dtype() if not require_storage_cast(feature) else object
172168
for name, dtype, feature in zip(schema.names, schema.types, self.config.features.values())
173169
}
174-
if schema
170+
if schema is not None
175171
else None
176172
)
177173
for file_idx, file in enumerate(files):

src/datasets/packaged_modules/json/json.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pyarrow.json as paj
88

99
import datasets
10-
from datasets.features.features import require_storage_cast
1110
from datasets.table import table_cast
1211
from datasets.utils.file_utils import readline
1312

@@ -26,10 +25,6 @@ class JsonConfig(datasets.BuilderConfig):
2625
chunksize: int = 10 << 20 # 10MB
2726
newlines_in_values: Optional[bool] = None
2827

29-
@property
30-
def schema(self):
31-
return self.features.arrow_schema if self.features is not None else None
32-
3328

3429
class Json(datasets.ArrowBasedBuilder):
3530
BUILDER_CONFIG_CLASS = JsonConfig
@@ -67,13 +62,9 @@ def _split_generators(self, dl_manager):
6762

6863
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
6964
if self.config.features is not None:
70-
schema = self.config.schema
71-
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
72-
# cheaper cast
73-
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
74-
else:
75-
# more expensive cast; allows str <-> int/float or str to Audio for example
76-
pa_table = table_cast(pa_table, schema)
65+
# more expensive cast to support nested structures with keys in a different order
66+
# allows str <-> int/float or str to Audio for example
67+
pa_table = table_cast(pa_table, self.config.features.arrow_schema)
7768
return pa_table
7869

7970
def _generate_tables(self, files):

src/datasets/packaged_modules/pandas/pandas.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ class PandasConfig(datasets.BuilderConfig):
1515

1616
features: Optional[datasets.Features] = None
1717

18-
@property
19-
def schema(self):
20-
return self.features.arrow_schema if self.features is not None else None
21-
2218

2319
class Pandas(datasets.ArrowBasedBuilder):
2420
BUILDER_CONFIG_CLASS = PandasConfig
@@ -45,7 +41,7 @@ def _split_generators(self, dl_manager):
4541

4642
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
4743
if self.config.features is not None:
48-
schema = self.config.schema
44+
schema = self.config.features.arrow_schema
4945
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
5046
# cheaper cast
5147
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)

src/datasets/packaged_modules/parquet/parquet.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ class ParquetConfig(datasets.BuilderConfig):
2020
columns: Optional[List[str]] = None
2121
features: Optional[datasets.Features] = None
2222

23-
@property
24-
def schema(self):
25-
return self.features.arrow_schema if self.features is not None else None
26-
2723

2824
class Parquet(datasets.ArrowBasedBuilder):
2925
BUILDER_CONFIG_CLASS = ParquetConfig
@@ -50,7 +46,7 @@ def _split_generators(self, dl_manager):
5046

5147
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
5248
if self.config.features is not None:
53-
schema = self.config.schema
49+
schema = self.config.features.arrow_schema
5450
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
5551
# cheaper cast
5652
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
@@ -60,7 +56,7 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
6056
return pa_table
6157

6258
def _generate_tables(self, files):
63-
schema = pa.schema(self.config.features.type) if self.config.features is not None else None
59+
schema = self.config.features.arrow_schema if self.config.features is not None else None
6460
if self.config.features is not None and self.config.columns is not None:
6561
if sorted(field.name for field in schema) != sorted(self.config.columns):
6662
raise ValueError(

src/datasets/packaged_modules/text/text.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ class TextConfig(datasets.BuilderConfig):
2222
keep_linebreaks: bool = False
2323
sample_by: str = "line"
2424

25-
@property
26-
def schema(self):
27-
return self.features.arrow_schema if self.features is not None else None
28-
2925

3026
class Text(datasets.ArrowBasedBuilder):
3127
BUILDER_CONFIG_CLASS = TextConfig
@@ -58,7 +54,7 @@ def _split_generators(self, dl_manager):
5854

5955
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
6056
if self.config.features is not None:
61-
schema = self.config.schema
57+
schema = self.config.features.arrow_schema
6258
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
6359
# cheaper cast
6460
pa_table = pa_table.cast(schema)

0 commit comments

Comments
 (0)