
Commit 8d828b9

Change features vs schema logic (#423)
* Change features vs schema logic
* test output features of dataset transforms
* style
* fix error msg in SplitInfo + make serializable
* fix dictionary_encode_column
1 parent e16f79b commit 8d828b9

File tree

7 files changed: +149 -123 lines changed

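At a glance, the change threads nlp.Features through the Dataset API where a raw pyarrow.Schema used to be: the `Dataset.schema` property and the `arrow_schema` argument of `.map()` are removed, and `info.features` is kept in sync with the underlying Arrow table instead. A rough sketch of how this looks from the outside, assuming the `nlp` API at this commit (the column names below are made up):

    from nlp.arrow_dataset import Dataset
    from nlp.features import Features, Value

    # Column types are now declared as Features rather than as a pyarrow schema.
    features = Features({"text": Value("string"), "label": Value("int64")})
    dset = Dataset.from_dict({"text": ["a", "b"], "label": [0, 1]}, features=features)

    print(dset.features)  # an nlp.Features object, kept in sync with the Arrow table
    print(dset)           # the repr now shows features instead of the raw Arrow schema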

src/nlp/arrow_dataset.py

Lines changed: 31 additions & 58 deletions
@@ -38,7 +38,7 @@
 from .info import DatasetInfo
 from .search import IndexableMixin
 from .splits import NamedSplit
-from .utils import map_all_sequences_to_lists, map_nested
+from .utils import map_nested
 
 
 logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ class DatasetInfoMixin(object):
     at the base level of the Dataset for easy access.
     """
 
-    def __init__(self, info: Optional[DatasetInfo], split: Optional[NamedSplit]):
+    def __init__(self, info: DatasetInfo, split: Optional[NamedSplit]):
         self._info = info
         self._split = split
 
@@ -92,7 +92,7 @@ def download_size(self) -> Optional[int]:
         return self._info.download_size
 
     @property
-    def features(self):
+    def features(self) -> Features:
         return self._info.features
 
     @property
@@ -131,6 +131,7 @@ def __init__(
         info: Optional[DatasetInfo] = None,
         split: Optional[NamedSplit] = None,
     ):
+        info = info.copy() if info is not None else DatasetInfo()
         DatasetInfoMixin.__init__(self, info=info, split=split)
         IndexableMixin.__init__(self)
         self._data: pa.Table = arrow_table
@@ -139,6 +140,15 @@ def __init__(
         self._format_kwargs: dict = {}
         self._format_columns: Optional[list] = None
         self._output_all_columns: bool = False
+        inferred_features = Features.from_arrow_schema(arrow_table.schema)
+        if self.info.features is not None:
+            if self.info.features.type != inferred_features.type:
+                self.info.features = inferred_features
+            else:
+                pass  # keep the original features
+        else:
+            self.info.features = inferred_features
+        assert self.features is not None, "Features can't be None in a Dataset object"
 
     @classmethod
     def from_file(
@@ -177,7 +187,7 @@ def from_pandas(
 
         Be aware that Series of the `object` dtype don't carry enough information to always lead to a meaningful Arrow type. In the case that
         we cannot infer a type, e.g. because the DataFrame is of length 0 or the Series only contains None/nan objects, the type is set to
-        null. This behavior can be avoided by constructing an explicit schema and passing it to this function.
+        null. This behavior can be avoided by constructing explicit features and passing it to this function.
 
         Args:
             df (:obj:``pandas.DataFrame``): the dataframe that contains the dataset.
@@ -186,21 +196,16 @@
                 description, citation, etc.
             split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split.
         """
-        if info is None:
-            info = DatasetInfo()
-        if info.features is None:
-            info.features = features
-        elif info.features != features and features is not None:
+        if info is not None and features is not None and info.features != features:
             raise ValueError(
                 "Features specified in `features` and `info.features` can't be different:\n{}\n{}".format(
                     features, info.features
                 )
             )
+        features = features if features is not None else info.feature if info is not None else None
         pa_table: pa.Table = pa.Table.from_pandas(
-            df=df, schema=pa.schema(info.features.type) if info.features is not None else None
+            df=df, schema=pa.schema(features.type) if features is not None else None
         )
-        if info.features is None:
-            info.features = Features.from_arrow_schema(pa_table.schema)
         return cls(pa_table, info=info, split=split)
 
     @classmethod
@@ -221,21 +226,16 @@ def from_dict(
                 description, citation, etc.
             split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split.
         """
-        if info is None:
-            info = DatasetInfo()
-        if info.features is None:
-            info.features = features
-        elif info.features != features and features is not None:
+        if info is not None and features is not None and info.features != features:
            raise ValueError(
                 "Features specified in `features` and `info.features` can't be different:\n{}\n{}".format(
                     features, info.features
                 )
             )
+        features = features if features is not None else info.feature if info is not None else None
         pa_table: pa.Table = pa.Table.from_pydict(
-            mapping=mapping, schema=pa.schema(info.features.type) if info.features is not None else None
+            mapping=mapping, schema=pa.schema(features.type) if features is not None else None
         )
-        if info.features is None:
-            info.features = Features.from_arrow_schema(pa_table.schema)
         return cls(pa_table, info=info, split=split)
 
     @property
@@ -277,14 +277,6 @@ def column_names(self) -> List[str]:
         """Names of the columns in the dataset. """
         return self._data.column_names
 
-    @property
-    def schema(self) -> pa.Schema:
-        """The Arrow schema of the Apache Arrow table backing the dataset.
-        You probably don't need to access directly this and can rather use
-        :func:`nlp.Dataset.features` to inspect the dataset features.
-        """
-        return self._data.schema
-
     @property
     def shape(self):
         """Shape of the dataset (number of columns, number of rows)."""
@@ -340,6 +332,7 @@ def dictionary_encode_column(self, column: str):
         casted_field = pa.field(field.name, pa.dictionary(pa.int32(), field.type), nullable=False)
         casted_schema.set(field_index, casted_field)
         self._data = self._data.cast(casted_schema)
+        self.info.features = Features.from_arrow_schema(self._data.schema)
 
     def flatten(self, max_depth=16):
         """ Flatten the Table.
@@ -352,7 +345,7 @@ def flatten(self, max_depth=16):
             else:
                 break
         if self.info is not None:
-            self.info.features = Features.from_arrow_schema(self.schema)
+            self.info.features = Features.from_arrow_schema(self._data.schema)
         logger.info(
             "Flattened dataset from depth {} to depth {}.".format(depth, 1 if depth + 1 < max_depth else "unknown")
         )
@@ -380,8 +373,7 @@ def __iter__(self):
         )
 
     def __repr__(self):
-        schema_str = dict((a, str(b)) for a, b in zip(self._data.schema.names, self._data.schema.types))
-        return f"Dataset(schema: {schema_str}, num_rows: {self.num_rows})"
+        return f"Dataset(features: {self.features}, num_rows: {self.num_rows})"
 
     @property
     def format(self):
@@ -685,7 +677,7 @@ def map(
         load_from_cache_file: bool = True,
         cache_file_name: Optional[str] = None,
         writer_batch_size: Optional[int] = 1000,
-        arrow_schema: Optional[pa.Schema] = None,
+        features: Optional[Features] = None,
         disable_nullable: bool = True,
         verbose: bool = True,
     ):
@@ -712,7 +704,7 @@
                 results of the computation instead of the automatically generated cache file name.
             `writer_batch_size` (`int`, default: `1000`): Number of rows per write operation for the cache file writer.
                 Higher value gives smaller cache files, lower value consume less temporary memory while running `.map()`.
-            `arrow_schema` (`Optional[pa.Schema]`, default: `None`): Use a specific Apache Arrow Schema to store the cache file
+            `features` (`Optional[nlp.Features]`, default: `None`): Use a specific Features to store the cache file
                 instead of the automatically generated one.
             `disable_nullable` (`bool`, default: `True`): Allow null values in the table.
             `verbose` (`bool`, default: `True`): Set to `False` to deactivate the tqdm progress bar and informations.
@@ -792,18 +784,6 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
             inputs.update(processed_inputs)
             return inputs
 
-        # Find the output schema if none is given
-        test_inputs = self[:2] if batched else self[0]
-        test_indices = [0, 1] if batched else 0
-        test_output = apply_function_on_filtered_inputs(test_inputs, test_indices)
-        if arrow_schema is None and update_data:
-            if not batched:
-                test_output = self._nest(test_output)
-            test_output = map_all_sequences_to_lists(test_output)
-            arrow_schema = pa.Table.from_pydict(test_output).schema
-            if disable_nullable:
-                arrow_schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in arrow_schema)
-
         # Check if we've already cached this computation (indexed by a hash)
         if self._data_files and update_data:
             if cache_file_name is None:
@@ -817,7 +797,7 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
                 "load_from_cache_file": load_from_cache_file,
                 "cache_file_name": cache_file_name,
                 "writer_batch_size": writer_batch_size,
-                "arrow_schema": arrow_schema,
+                "features": features,
                 "disable_nullable": disable_nullable,
             }
             cache_file_name = self._get_cache_file_path(function, cache_kwargs)
@@ -830,12 +810,12 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
         if update_data:
             if keep_in_memory or not self._data_files:
                 buf_writer = pa.BufferOutputStream()
-                writer = ArrowWriter(schema=arrow_schema, stream=buf_writer, writer_batch_size=writer_batch_size)
+                writer = ArrowWriter(features=features, stream=buf_writer, writer_batch_size=writer_batch_size)
             else:
                 buf_writer = None
                 if verbose:
                     logger.info("Caching processed dataset at %s", cache_file_name)
-                writer = ArrowWriter(schema=arrow_schema, path=cache_file_name, writer_batch_size=writer_batch_size)
+                writer = ArrowWriter(features=features, path=cache_file_name, writer_batch_size=writer_batch_size)
 
         # Loop over single examples or batches and write to buffer/file if examples are to be updated
         if not batched:
@@ -928,15 +908,8 @@ def map_function(batch, *args):
 
             return result
 
-        # to avoid errors with the arrow_schema we define it here
-        test_inputs = self[:2]
-        if "remove_columns" in kwargs:
-            test_inputs = {key: test_inputs[key] for key in (test_inputs.keys() - kwargs["remove_columns"])}
-        test_inputs = map_all_sequences_to_lists(test_inputs)
-        arrow_schema = pa.Table.from_pydict(test_inputs).schema
-
         # return map function
-        return self.map(map_function, batched=True, with_indices=with_indices, arrow_schema=arrow_schema, **kwargs)
+        return self.map(map_function, batched=True, with_indices=with_indices, features=self.features, **kwargs)
 
     def select(
         self,
@@ -991,12 +964,12 @@ def select(
         # Prepare output buffer and batched writer in memory or on file if we update the table
        if keep_in_memory or not self._data_files:
            buf_writer = pa.BufferOutputStream()
-            writer = ArrowWriter(schema=self.schema, stream=buf_writer, writer_batch_size=writer_batch_size)
+            writer = ArrowWriter(features=self.features, stream=buf_writer, writer_batch_size=writer_batch_size)
        else:
            buf_writer = None
            if verbose:
                logger.info("Caching processed dataset at %s", cache_file_name)
-            writer = ArrowWriter(schema=self.schema, path=cache_file_name, writer_batch_size=writer_batch_size)
+            writer = ArrowWriter(features=self.features, path=cache_file_name, writer_batch_size=writer_batch_size)
 
        # Loop over batches and write to buffer/file if examples are to be updated
        for i in tqdm(range(0, len(indices), reader_batch_size), disable=not verbose):
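In `from_pandas`/`from_dict`, the old behavior of silently merging `features` into `info.features` is replaced by a single rule: if both are passed they must agree, otherwise whichever one is present is used, and if neither is present `__init__` infers features with `Features.from_arrow_schema`. A minimal sketch of the three paths, assuming the API at this commit and that `DatasetInfo` accepts `features` as a keyword (the data and column names are illustrative):

    from nlp.arrow_dataset import Dataset
    from nlp.features import Features, Value
    from nlp.info import DatasetInfo

    mapping = {"text": ["hello", "world"]}
    string_features = Features({"text": Value("string")})

    # 1) `features` and `info.features` agree: accepted.
    dset = Dataset.from_dict(mapping, features=string_features, info=DatasetInfo(features=string_features))

    # 2) They disagree: ValueError, raised before any Arrow table is built.
    try:
        Dataset.from_dict(
            mapping, features=Features({"text": Value("int64")}), info=DatasetInfo(features=string_features)
        )
    except ValueError as err:
        print(err)

    # 3) Neither is given: __init__ falls back to Features.from_arrow_schema(arrow_table.schema).
    print(Dataset.from_dict(mapping).features)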

src/nlp/arrow_writer.py

Lines changed: 26 additions & 21 deletions
@@ -23,6 +23,7 @@
 
 import pyarrow as pa
 
+from .features import Features
 from .utils.file_utils import HF_DATASETS_CACHE, hash_url_to_filename
 from .utils.py_utils import map_all_sequences_to_lists
 
@@ -42,27 +43,35 @@ def __init__(
         self,
         data_type: Optional[pa.DataType] = None,
         schema: Optional[pa.Schema] = None,
+        features: Optional[Features] = None,
         path: Optional[str] = None,
         stream: Optional[pa.NativeFile] = None,
         writer_batch_size: Optional[int] = None,
         disable_nullable: bool = True,
     ):
         if path is None and stream is None:
             raise ValueError("At least one of path and stream must be provided.")
-
-        if data_type is not None:
+        if features is not None:
+            self._features = features
+            self._schema = pa.schema(features.type) if features is not None else None
+            self._type: pa.DataType = pa.struct(field for field in self._schema)
+        elif data_type is not None:
             self._type: pa.DataType = data_type
             self._schema: pa.Schema = pa.schema(field for field in self._type)
+            self._features = Features.from_arrow_schema(self._schema)
         elif schema is not None:
             self._schema: pa.Schema = schema
             self._type: pa.DataType = pa.struct(field for field in self._schema)
+            self._features = Features.from_arrow_schema(self._schema)
         else:
+            self._features = None
             self._schema = None
             self._type = None
 
         if disable_nullable and self._schema is not None:
             self._schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in self._type)
             self._type = pa.struct(pa.field(field.name, field.type, nullable=False) for field in self._type)
+            self._features = Features.from_arrow_schema(self._schema)
 
         self._path = path
         if stream is None:
@@ -76,19 +85,15 @@ def __init__(
         self._num_bytes = 0
         self.current_rows = []
 
-        self._build_writer(schema=self._schema)
+        self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
+        if self._schema is not None:
+            self._build_writer(schema=self._schema)
 
-    def _build_writer(self, pa_table=None, schema=None):
-        if schema is not None:
-            self._schema: pa.Schema = schema
-            self._type: pa.DataType = pa.struct(field for field in self._schema)
-            self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema)
-        elif pa_table is not None:
-            self._schema: pa.Schema = pa_table.schema
-            self._type: pa.DataType = pa.struct(field for field in self._schema)
-            self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
-        else:
-            self.pa_writer = None
+    def _build_writer(self, schema: pa.Schema):
+        self._schema: pa.Schema = schema
+        self._type: pa.DataType = pa.struct(field for field in self._schema)
+        self._features = Features.from_arrow_schema(self._schema)
+        self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema)
 
     @property
     def schema(self):
@@ -98,6 +103,9 @@ def _write_array_on_file(self, pa_array):
         """Write a PyArrow Array"""
         pa_batch = pa.RecordBatch.from_struct_array(pa_array)
         self._num_bytes += pa_array.nbytes
+        if self.pa_writer is None:
+            pa_table = pa.Table.from_batches([pa_batch])
+            self._build_writer(schema=pa_table.schema)
         self.pa_writer.write_batch(pa_batch)
 
     def write_on_file(self):
@@ -141,8 +149,6 @@ def write(self, example: Dict[str, Any], writer_batch_size: Optional[int] = None
         self._num_examples += 1
         if writer_batch_size is None:
             writer_batch_size = self.writer_batch_size
-        if self.pa_writer is None:
-            self._build_writer(pa_table=pa.Table.from_pydict(example))
         if writer_batch_size is not None and len(self.current_rows) >= writer_batch_size:
             self.write_on_file()
 
@@ -156,7 +162,7 @@ def write_batch(
         """
         batch_examples = map_all_sequences_to_lists(batch_examples)
         if self.pa_writer is None:
-            self._build_writer(pa_table=pa.Table.from_pydict(batch_examples))
+            self._build_writer(schema=pa.Table.from_pydict(batch_examples).schema)
         pa_table: pa.Table = pa.Table.from_pydict(batch_examples, schema=self._schema)
         if writer_batch_size is None:
             writer_batch_size = self.writer_batch_size
@@ -175,17 +181,16 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non
         if writer_batch_size is None:
             writer_batch_size = self.writer_batch_size
         if self.pa_writer is None:
-            self._build_writer(pa_table=pa_table)
+            self._build_writer(schema=pa_table.schema)
         batches: List[pa.RecordBatch] = pa_table.to_batches(max_chunksize=writer_batch_size)
         self._num_bytes += sum(batch.nbytes for batch in batches)
         self._num_examples += pa_table.num_rows
         for batch in batches:
             self.pa_writer.write_batch(batch)
 
     def finalize(self, close_stream=True):
-        if self.pa_writer is not None:
-            self.write_on_file()
-            self.pa_writer.close()
+        self.write_on_file()
+        self.pa_writer.close()
         if close_stream:
             self.stream.close()
         logger.info(
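`ArrowWriter` grows a matching `features=` argument, and when neither `features`, `schema`, nor `data_type` is given it now defers creating the `pa.RecordBatchStreamWriter` until the first write instead of building it eagerly in `__init__`. A usage sketch of both paths, assuming the internal API at this commit (the output path is illustrative):

    import pyarrow as pa

    from nlp.arrow_writer import ArrowWriter
    from nlp.features import Features, Value

    features = Features({"text": Value("string"), "label": Value("int64")})

    # Features known up-front: the schema and the stream writer are built in __init__.
    writer = ArrowWriter(features=features, path="/tmp/with_features.arrow")
    writer.write_batch({"text": ["a", "b"], "label": [0, 1]})
    writer.finalize()

    # Nothing known up-front: pa_writer stays None until the first batch arrives,
    # at which point _build_writer() derives the schema (and features) from it.
    buf = pa.BufferOutputStream()
    lazy_writer = ArrowWriter(stream=buf)
    lazy_writer.write_batch({"text": ["c"], "label": [2]})
    lazy_writer.finalize()
    print(lazy_writer.schema)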

src/nlp/features.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,10 @@ class Value:
     _type: str = field(default="Value", init=False, repr=False)
 
     def __post_init__(self):
+        if self.dtype == "double":  # fix inferred type
+            self.dtype = "float64"
+        if self.dtype == "float":  # fix inferred type
+            self.dtype = "float32"
         self.pa_type = string_to_arrow(self.dtype)
 
     def __call__(self):
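The `Value.__post_init__` normalization handles the dtype aliases Arrow uses for floats (`double` for 64-bit, `float` for 32-bit), so features inferred from an Arrow schema compare equal to user-declared `float64`/`float32` features. A quick sketch of the effect, assuming the `Value` API at this commit:

    from nlp.features import Value

    print(Value("double").dtype)   # normalized to "float64"
    print(Value("float").dtype)    # normalized to "float32"
    print(Value("float32").dtype)  # unchanged: "float32"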
