Skip to content

Commit 6ddd1bd

Browse files
committed
Add table_cast to loaders
1 parent 9f93c4b commit 6ddd1bd

File tree

5 files changed

+114
-33
lines changed

5 files changed

+114
-33
lines changed

src/datasets/packaged_modules/csv/csv.py

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77

88
import datasets
99
import datasets.config
10+
from datasets.features.features import require_cast_storage
11+
from datasets.table import table_cast
1012

1113

1214
logger = datasets.utils.logging.get_logger(__name__)
@@ -66,6 +68,10 @@ def __post_init__(self):
6668
if self.column_names is not None:
6769
self.names = self.column_names
6870

71+
@property
72+
def schema(self):
73+
return self.features.arrow_schema if self.features is not None else None
74+
6975
@property
7076
def read_csv_kwargs(self):
7177
read_csv_kwargs = dict(
@@ -146,19 +152,37 @@ def _split_generators(self, dl_manager):
146152
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
147153
return splits
148154

155+
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
156+
if self.config.features is not None:
157+
schema = self.config.schema
158+
if all(not require_cast_storage(feature) for feature in self.config.features.values()):
159+
# cheaper cast
160+
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
161+
else:
162+
# more expensive cast; allows str <-> int/float or str to Audio for example
163+
pa_table = table_cast(pa_table, schema)
164+
return pa_table
165+
149166
def _generate_tables(self, files):
150-
schema = pa.schema(self.config.features.type) if self.config.features is not None else None
167+
schema = self.config.schema
151168
# dtype allows reading an int column as str
152-
dtype = {name: dtype.to_pandas_dtype() for name, dtype in zip(schema.names, schema.types)} if schema else None
169+
dtype = (
170+
{
171+
name: dtype.to_pandas_dtype() if not require_cast_storage(feature) else object
172+
for name, dtype, feature in zip(schema.names, schema.types, self.config.features.values())
173+
}
174+
if schema
175+
else None
176+
)
153177
for file_idx, file in enumerate(files):
154178
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)
155179
try:
156180
for batch_idx, df in enumerate(csv_file_reader):
157-
pa_table = pa.Table.from_pandas(df, schema=schema)
181+
pa_table = pa.Table.from_pandas(df)
158182
# Uncomment for debugging (will print the Arrow table size and elements)
159183
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
160184
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
161-
yield (file_idx, batch_idx), pa_table
185+
yield (file_idx, batch_idx), self._cast_table(pa_table)
162186
except ValueError as e:
163187
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
164188
raise

src/datasets/packaged_modules/json/json.py

Lines changed: 12 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import pyarrow.json as paj
88

99
import datasets
10+
from datasets.features.features import require_cast_storage
1011
from datasets.table import table_cast
1112
from datasets.utils.file_utils import readline
1213

@@ -64,21 +65,15 @@ def _split_generators(self, dl_manager):
6465
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
6566
return splits
6667

67-
def _cast_classlabels(self, pa_table: pa.Table) -> pa.Table:
68-
if self.config.features:
69-
# Encode column if ClassLabel
70-
for i, col in enumerate(self.config.features.keys()):
71-
if isinstance(self.config.features[col], datasets.ClassLabel):
72-
if pa_table[col].type == pa.string():
73-
pa_table = pa_table.set_column(
74-
i, self.config.schema.field(col), [self.config.features[col].str2int(pa_table[col])]
75-
)
76-
elif pa_table[col].type != self.config.schema.field(col).type:
77-
raise ValueError(
78-
f"Field '{col}' from the JSON data of type {pa_table[col].type} is not compatible with ClassLabel. Compatible types are int64 and string."
79-
)
80-
# Cast allows str <-> int/float or str to Audio for example
81-
pa_table = table_cast(pa_table, self.config.schema)
68+
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
69+
if self.config.features is not None:
70+
schema = self.config.schema
71+
if all(not require_cast_storage(feature) for feature in self.config.features.values()):
72+
# cheaper cast
73+
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
74+
else:
75+
# more expensive cast; allows str <-> int/float or str to Audio for example
76+
pa_table = table_cast(pa_table, schema)
8277
return pa_table
8378

8479
def _generate_tables(self, files):
@@ -98,7 +93,7 @@ def _generate_tables(self, files):
9893
else:
9994
mapping = dataset
10095
pa_table = pa.Table.from_pydict(mapping=mapping)
101-
yield file_idx, self._cast_classlabels(pa_table)
96+
yield file_idx, self._cast_table(pa_table)
10297

10398
# If the file has one json object per line
10499
else:
@@ -153,5 +148,5 @@ def _generate_tables(self, files):
153148
# Uncomment for debugging (will print the Arrow table size and elements)
154149
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
155150
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
156-
yield (file_idx, batch_idx), self._cast_classlabels(pa_table)
151+
yield (file_idx, batch_idx), self._cast_table(pa_table)
157152
batch_idx += 1

src/datasets/packaged_modules/pandas/pandas.py

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,30 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
14
import pandas as pd
25
import pyarrow as pa
36

47
import datasets
8+
from datasets.features.features import require_cast_storage
9+
from datasets.table import table_cast
10+
11+
12+
@dataclass
13+
class PandasConfig(datasets.BuilderConfig):
14+
"""BuilderConfig for Pandas."""
15+
16+
features: Optional[datasets.Features] = None
17+
18+
@property
19+
def schema(self):
20+
return self.features.arrow_schema if self.features is not None else None
521

622

723
class Pandas(datasets.ArrowBasedBuilder):
24+
BUILDER_CONFIG_CLASS = PandasConfig
25+
826
def _info(self):
9-
return datasets.DatasetInfo()
27+
return datasets.DatasetInfo(features=self.config.features)
1028

1129
def _split_generators(self, dl_manager):
1230
"""We handle string, list and dicts in datafiles"""
@@ -25,8 +43,19 @@ def _split_generators(self, dl_manager):
2543
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
2644
return splits
2745

46+
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
47+
if self.config.features is not None:
48+
schema = self.config.schema
49+
if all(not require_cast_storage(feature) for feature in self.config.features.values()):
50+
# cheaper cast
51+
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
52+
else:
53+
# more expensive cast; allows str <-> int/float or str to Audio for example
54+
pa_table = table_cast(pa_table, schema)
55+
return pa_table
56+
2857
def _generate_tables(self, files):
2958
for i, file in enumerate(files):
3059
with open(file, "rb") as f:
3160
pa_table = pa.Table.from_pandas(pd.read_pickle(f))
32-
yield i, pa_table
61+
yield i, self._cast_table(pa_table)

src/datasets/packaged_modules/parquet/parquet.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@
55
import pyarrow.parquet as pq
66

77
import datasets
8+
from datasets.features.features import require_cast_storage
9+
from datasets.table import table_cast
810

911

1012
logger = datasets.utils.logging.get_logger(__name__)
@@ -18,6 +20,10 @@ class ParquetConfig(datasets.BuilderConfig):
1820
columns: Optional[List[str]] = None
1921
features: Optional[datasets.Features] = None
2022

23+
@property
24+
def schema(self):
25+
return self.features.arrow_schema if self.features is not None else None
26+
2127

2228
class Parquet(datasets.ArrowBasedBuilder):
2329
BUILDER_CONFIG_CLASS = ParquetConfig
@@ -42,6 +48,17 @@ def _split_generators(self, dl_manager):
4248
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
4349
return splits
4450

51+
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
52+
if self.config.features is not None:
53+
schema = self.config.schema
54+
if all(not require_cast_storage(feature) for feature in self.config.features.values()):
55+
# cheaper cast
56+
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
57+
else:
58+
# more expensive cast; allows str <-> int/float or str to Audio for example
59+
pa_table = table_cast(pa_table, schema)
60+
return pa_table
61+
4562
def _generate_tables(self, files):
4663
schema = pa.schema(self.config.features.type) if self.config.features is not None else None
4764
if self.config.features is not None and self.config.columns is not None:
@@ -57,12 +74,10 @@ def _generate_tables(self, files):
5774
parquet_file.iter_batches(batch_size=self.config.batch_size, columns=self.config.columns)
5875
):
5976
pa_table = pa.Table.from_batches([record_batch])
60-
if self.config.features is not None:
61-
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
6277
# Uncomment for debugging (will print the Arrow table size and elements)
6378
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
6479
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
65-
yield f"{file_idx}_{batch_idx}", pa_table
80+
yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table)
6681
except ValueError as e:
6782
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
6883
raise

src/datasets/packaged_modules/text/text.py

Lines changed: 25 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@
55
import pyarrow as pa
66

77
import datasets
8+
from datasets.features.features import require_cast_storage
9+
from datasets.table import table_cast
810

911

1012
logger = datasets.utils.logging.get_logger(__name__)
@@ -20,6 +22,10 @@ class TextConfig(datasets.BuilderConfig):
2022
keep_linebreaks: bool = False
2123
sample_by: str = "line"
2224

25+
@property
26+
def schema(self):
27+
return self.features.arrow_schema if self.features is not None else None
28+
2329

2430
class Text(datasets.ArrowBasedBuilder):
2531
BUILDER_CONFIG_CLASS = TextConfig
@@ -50,8 +56,20 @@ def _split_generators(self, dl_manager):
5056
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
5157
return splits
5258

59+
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
60+
if self.config.features is not None:
61+
schema = self.config.schema
62+
if all(not require_cast_storage(feature) for feature in self.config.features.values()):
63+
# cheaper cast
64+
pa_table = pa_table.cast(schema)
65+
else:
66+
# more expensive cast; allows str <-> int/float or str to Audio for example
67+
pa_table = table_cast(pa_table, schema)
68+
return pa_table
69+
else:
70+
return pa_table.cast(pa.schema({"text": pa.string()}))
71+
5372
def _generate_tables(self, files):
54-
schema = pa.schema(self.config.features.type if self.config.features is not None else {"text": pa.string()})
5573
for file_idx, file in enumerate(files):
5674
# open in text mode, by default translates universal newlines ("\n", "\r\n" and "\r") into "\n"
5775
with open(file, encoding=self.config.encoding) as f:
@@ -66,11 +84,11 @@ def _generate_tables(self, files):
6684
batch = StringIO(batch).readlines()
6785
if not self.config.keep_linebreaks:
6886
batch = [line.rstrip("\n") for line in batch]
69-
pa_table = pa.Table.from_arrays([pa.array(batch)], schema=schema)
87+
pa_table = pa.Table.from_arrays([pa.array(batch)], names=["text"])
7088
# Uncomment for debugging (will print the Arrow table size and elements)
7189
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
7290
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
73-
yield (file_idx, batch_idx), pa_table
91+
yield (file_idx, batch_idx), self._cast_table(pa_table)
7492
batch_idx += 1
7593
elif self.config.sample_by == "paragraph":
7694
batch_idx = 0
@@ -82,15 +100,15 @@ def _generate_tables(self, files):
82100
batch += f.readline() # finish current line
83101
batch = batch.split("\n\n")
84102
pa_table = pa.Table.from_arrays(
85-
[pa.array([example for example in batch[:-1] if example])], schema=schema
103+
[pa.array([example for example in batch[:-1] if example])], names=["text"]
86104
)
87105
# Uncomment for debugging (will print the Arrow table size and elements)
88106
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
89107
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
90-
yield (file_idx, batch_idx), pa_table
108+
yield (file_idx, batch_idx), self._cast_table(pa_table)
91109
batch_idx += 1
92110
batch = batch[-1]
93111
elif self.config.sample_by == "document":
94112
text = f.read()
95-
pa_table = pa.Table.from_arrays([pa.array([text])], schema=schema)
96-
yield file_idx, pa_table
113+
pa_table = pa.Table.from_arrays([pa.array([text])], names=["text"])
114+
yield file_idx, self._cast_table(pa_table)

0 commit comments

Comments
 (0)