89 changes: 31 additions & 58 deletions src/nlp/arrow_dataset.py
@@ -38,7 +38,7 @@
from .info import DatasetInfo
from .search import IndexableMixin
from .splits import NamedSplit
from .utils import map_all_sequences_to_lists, map_nested
from .utils import map_nested


logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ class DatasetInfoMixin(object):
at the base level of the Dataset for easy access.
"""

def __init__(self, info: Optional[DatasetInfo], split: Optional[NamedSplit]):
def __init__(self, info: DatasetInfo, split: Optional[NamedSplit]):
self._info = info
self._split = split

@@ -92,7 +92,7 @@ def download_size(self) -> Optional[int]:
return self._info.download_size

@property
def features(self):
def features(self) -> Features:
return self._info.features

@property
@@ -131,6 +131,7 @@ def __init__(
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
):
info = info.copy() if info is not None else DatasetInfo()
DatasetInfoMixin.__init__(self, info=info, split=split)
IndexableMixin.__init__(self)
self._data: pa.Table = arrow_table
@@ -139,6 +140,15 @@
self._format_kwargs: dict = {}
self._format_columns: Optional[list] = None
self._output_all_columns: bool = False
inferred_features = Features.from_arrow_schema(arrow_table.schema)
if self.info.features is not None:
if self.info.features.type != inferred_features.type:
self.info.features = inferred_features
else:
pass # keep the original features
else:
self.info.features = inferred_features
assert self.features is not None, "Features can't be None in a Dataset object"

@classmethod
def from_file(
@@ -177,7 +187,7 @@ def from_pandas(

Be aware that Series of the `object` dtype don't carry enough information to always lead to a meaningful Arrow type. In the case that
we cannot infer a type, e.g. because the DataFrame is of length 0 or the Series only contains None/nan objects, the type is set to
null. This behavior can be avoided by constructing an explicit schema and passing it to this function.
null. This behavior can be avoided by constructing explicit features and passing them to this function.

Args:
df (:obj:``pandas.DataFrame``): the dataframe that contains the dataset.
@@ -186,21 +196,16 @@ def from_pandas(
description, citation, etc.
split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split.
"""
if info is None:
info = DatasetInfo()
if info.features is None:
info.features = features
elif info.features != features and features is not None:
if info is not None and features is not None and info.features != features:
raise ValueError(
"Features specified in `features` and `info.features` can't be different:\n{}\n{}".format(
features, info.features
)
)
features = features if features is not None else info.features if info is not None else None
pa_table: pa.Table = pa.Table.from_pandas(
df=df, schema=pa.schema(info.features.type) if info.features is not None else None
df=df, schema=pa.schema(features.type) if features is not None else None
)
if info.features is None:
info.features = Features.from_arrow_schema(pa_table.schema)
return cls(pa_table, info=info, split=split)

@classmethod
@@ -221,21 +226,16 @@ def from_dict(
description, citation, etc.
split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split.
"""
if info is None:
info = DatasetInfo()
if info.features is None:
info.features = features
elif info.features != features and features is not None:
if info is not None and features is not None and info.features != features:
raise ValueError(
"Features specified in `features` and `info.features` can't be different:\n{}\n{}".format(
features, info.features
)
)
features = features if features is not None else info.features if info is not None else None
pa_table: pa.Table = pa.Table.from_pydict(
mapping=mapping, schema=pa.schema(info.features.type) if info.features is not None else None
mapping=mapping, schema=pa.schema(features.type) if features is not None else None
)
if info.features is None:
info.features = Features.from_arrow_schema(pa_table.schema)
return cls(pa_table, info=info, split=split)

@property
@@ -277,14 +277,6 @@ def column_names(self) -> List[str]:
"""Names of the columns in the dataset. """
return self._data.column_names

@property
def schema(self) -> pa.Schema:
"""The Arrow schema of the Apache Arrow table backing the dataset.
You probably don't need to access this directly and can instead use
:func:`nlp.Dataset.features` to inspect the dataset features.
"""
return self._data.schema

@property
def shape(self):
"""Shape of the dataset (number of columns, number of rows)."""
@@ -340,6 +332,7 @@ def dictionary_encode_column(self, column: str):
casted_field = pa.field(field.name, pa.dictionary(pa.int32(), field.type), nullable=False)
casted_schema.set(field_index, casted_field)
self._data = self._data.cast(casted_schema)
self.info.features = Features.from_arrow_schema(self._data.schema)

def flatten(self, max_depth=16):
""" Flatten the Table.
@@ -352,7 +345,7 @@ def flatten(self, max_depth=16):
else:
break
if self.info is not None:
self.info.features = Features.from_arrow_schema(self.schema)
self.info.features = Features.from_arrow_schema(self._data.schema)
logger.info(
"Flattened dataset from depth {} to depth {}.".format(depth, 1 if depth + 1 < max_depth else "unknown")
)
@@ -380,8 +373,7 @@ def __iter__(self):
)

def __repr__(self):
schema_str = dict((a, str(b)) for a, b in zip(self._data.schema.names, self._data.schema.types))
return f"Dataset(schema: {schema_str}, num_rows: {self.num_rows})"
return f"Dataset(features: {self.features}, num_rows: {self.num_rows})"

@property
def format(self):
@@ -685,7 +677,7 @@ def map(
load_from_cache_file: bool = True,
cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
arrow_schema: Optional[pa.Schema] = None,
features: Optional[Features] = None,
disable_nullable: bool = True,
verbose: bool = True,
):
@@ -712,7 +704,7 @@ def map(
results of the computation instead of the automatically generated cache file name.
`writer_batch_size` (`int`, default: `1000`): Number of rows per write operation for the cache file writer.
A higher value gives smaller cache files, a lower value consumes less temporary memory while running `.map()`.
`arrow_schema` (`Optional[pa.Schema]`, default: `None`): Use a specific Apache Arrow Schema to store the cache file
`features` (`Optional[nlp.Features]`, default: `None`): Use a specific Features to store the cache file
instead of the automatically generated one.
`disable_nullable` (`bool`, default: `True`): Disallow null values in the table.
`verbose` (`bool`, default: `True`): Set to `False` to deactivate the tqdm progress bar and informational messages.
@@ -792,18 +784,6 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
inputs.update(processed_inputs)
return inputs

# Find the output schema if none is given
test_inputs = self[:2] if batched else self[0]
test_indices = [0, 1] if batched else 0
test_output = apply_function_on_filtered_inputs(test_inputs, test_indices)
if arrow_schema is None and update_data:
if not batched:
test_output = self._nest(test_output)
test_output = map_all_sequences_to_lists(test_output)
arrow_schema = pa.Table.from_pydict(test_output).schema
if disable_nullable:
arrow_schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in arrow_schema)

# Check if we've already cached this computation (indexed by a hash)
if self._data_files and update_data:
if cache_file_name is None:
@@ -817,7 +797,7 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
"load_from_cache_file": load_from_cache_file,
"cache_file_name": cache_file_name,
"writer_batch_size": writer_batch_size,
"arrow_schema": arrow_schema,
"features": features,
"disable_nullable": disable_nullable,
}
cache_file_name = self._get_cache_file_path(function, cache_kwargs)
@@ -830,12 +810,12 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
if update_data:
if keep_in_memory or not self._data_files:
buf_writer = pa.BufferOutputStream()
writer = ArrowWriter(schema=arrow_schema, stream=buf_writer, writer_batch_size=writer_batch_size)
writer = ArrowWriter(features=features, stream=buf_writer, writer_batch_size=writer_batch_size)
else:
buf_writer = None
if verbose:
logger.info("Caching processed dataset at %s", cache_file_name)
writer = ArrowWriter(schema=arrow_schema, path=cache_file_name, writer_batch_size=writer_batch_size)
writer = ArrowWriter(features=features, path=cache_file_name, writer_batch_size=writer_batch_size)

# Loop over single examples or batches and write to buffer/file if examples are to be updated
if not batched:
@@ -928,15 +908,8 @@ def map_function(batch, *args):

return result

# to avoid errors with the arrow_schema we define it here
test_inputs = self[:2]
if "remove_columns" in kwargs:
test_inputs = {key: test_inputs[key] for key in (test_inputs.keys() - kwargs["remove_columns"])}
test_inputs = map_all_sequences_to_lists(test_inputs)
arrow_schema = pa.Table.from_pydict(test_inputs).schema

# return map function
return self.map(map_function, batched=True, with_indices=with_indices, arrow_schema=arrow_schema, **kwargs)
return self.map(map_function, batched=True, with_indices=with_indices, features=self.features, **kwargs)

def select(
self,
@@ -991,12 +964,12 @@ def select(
# Prepare output buffer and batched writer in memory or on file if we update the table
if keep_in_memory or not self._data_files:
buf_writer = pa.BufferOutputStream()
writer = ArrowWriter(schema=self.schema, stream=buf_writer, writer_batch_size=writer_batch_size)
writer = ArrowWriter(features=self.features, stream=buf_writer, writer_batch_size=writer_batch_size)
else:
buf_writer = None
if verbose:
logger.info("Caching processed dataset at %s", cache_file_name)
writer = ArrowWriter(schema=self.schema, path=cache_file_name, writer_batch_size=writer_batch_size)
writer = ArrowWriter(features=self.features, path=cache_file_name, writer_batch_size=writer_batch_size)

# Loop over batches and write to buffer/file if examples are to be updated
for i in tqdm(range(0, len(indices), reader_batch_size), disable=not verbose):
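Taken together, the arrow_dataset.py changes swap the inferred pa.Schema plumbing for nlp.Features end to end: the constructor backfills info.features from the Arrow schema, from_pandas/from_dict check user-supplied features against info.features, and map takes features instead of arrow_schema. A minimal sketch of that flow, assuming the public nlp package exposes Dataset, Features and Value at the top level; the column names and values below are illustrative, not taken from this PR:

# Sketch only: exercises the features-based API this diff introduces.
import nlp

features = nlp.Features({"text": nlp.Value("string"), "label": nlp.Value("int64")})
dataset = nlp.Dataset.from_dict(
    {"text": ["a", "b"], "label": [0, 1]},
    features=features,  # stored on dataset.info.features rather than as a raw pa.Schema
)

# map() now types its cache file with `features` instead of `arrow_schema`.
new_features = nlp.Features(
    {"text": nlp.Value("string"), "label": nlp.Value("int64"), "text_len": nlp.Value("int64")}
)
dataset = dataset.map(lambda example: {"text_len": len(example["text"])}, features=new_features)
print(dataset.features)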
47 changes: 26 additions & 21 deletions src/nlp/arrow_writer.py
@@ -23,6 +23,7 @@

import pyarrow as pa

from .features import Features
from .utils.file_utils import HF_DATASETS_CACHE, hash_url_to_filename
from .utils.py_utils import map_all_sequences_to_lists

@@ -42,27 +43,35 @@ def __init__(
self,
data_type: Optional[pa.DataType] = None,
schema: Optional[pa.Schema] = None,
features: Optional[Features] = None,
path: Optional[str] = None,
stream: Optional[pa.NativeFile] = None,
writer_batch_size: Optional[int] = None,
disable_nullable: bool = True,
):
if path is None and stream is None:
raise ValueError("At least one of path and stream must be provided.")

if data_type is not None:
if features is not None:
self._features = features
self._schema = pa.schema(features.type)
self._type: pa.DataType = pa.struct(field for field in self._schema)
elif data_type is not None:
self._type: pa.DataType = data_type
self._schema: pa.Schema = pa.schema(field for field in self._type)
self._features = Features.from_arrow_schema(self._schema)
elif schema is not None:
self._schema: pa.Schema = schema
self._type: pa.DataType = pa.struct(field for field in self._schema)
self._features = Features.from_arrow_schema(self._schema)
else:
self._features = None
self._schema = None
self._type = None

if disable_nullable and self._schema is not None:
self._schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in self._type)
self._type = pa.struct(pa.field(field.name, field.type, nullable=False) for field in self._type)
self._features = Features.from_arrow_schema(self._schema)

self._path = path
if stream is None:
@@ -76,19 +85,15 @@ def __init__(
self._num_bytes = 0
self.current_rows = []

self._build_writer(schema=self._schema)
self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
if self._schema is not None:
self._build_writer(schema=self._schema)

def _build_writer(self, pa_table=None, schema=None):
if schema is not None:
self._schema: pa.Schema = schema
self._type: pa.DataType = pa.struct(field for field in self._schema)
self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema)
elif pa_table is not None:
self._schema: pa.Schema = pa_table.schema
self._type: pa.DataType = pa.struct(field for field in self._schema)
self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema)
else:
self.pa_writer = None
def _build_writer(self, schema: pa.Schema):
self._schema: pa.Schema = schema
self._type: pa.DataType = pa.struct(field for field in self._schema)
self._features = Features.from_arrow_schema(self._schema)
self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema)

@property
def schema(self):
@@ -98,6 +103,9 @@ def _write_array_on_file(self, pa_array):
"""Write a PyArrow Array"""
pa_batch = pa.RecordBatch.from_struct_array(pa_array)
self._num_bytes += pa_array.nbytes
if self.pa_writer is None:
pa_table = pa.Table.from_batches([pa_batch])
self._build_writer(schema=pa_table.schema)
self.pa_writer.write_batch(pa_batch)

def write_on_file(self):
@@ -141,8 +149,6 @@ def write(self, example: Dict[str, Any], writer_batch_size: Optional[int] = None
self._num_examples += 1
if writer_batch_size is None:
writer_batch_size = self.writer_batch_size
if self.pa_writer is None:
self._build_writer(pa_table=pa.Table.from_pydict(example))
if writer_batch_size is not None and len(self.current_rows) >= writer_batch_size:
self.write_on_file()

@@ -156,7 +162,7 @@ def write_batch(
"""
batch_examples = map_all_sequences_to_lists(batch_examples)
if self.pa_writer is None:
self._build_writer(pa_table=pa.Table.from_pydict(batch_examples))
self._build_writer(schema=pa.Table.from_pydict(batch_examples).schema)
pa_table: pa.Table = pa.Table.from_pydict(batch_examples, schema=self._schema)
if writer_batch_size is None:
writer_batch_size = self.writer_batch_size
@@ -175,17 +181,16 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non
if writer_batch_size is None:
writer_batch_size = self.writer_batch_size
if self.pa_writer is None:
self._build_writer(pa_table=pa_table)
self._build_writer(schema=pa_table.schema)
batches: List[pa.RecordBatch] = pa_table.to_batches(max_chunksize=writer_batch_size)
self._num_bytes += sum(batch.nbytes for batch in batches)
self._num_examples += pa_table.num_rows
for batch in batches:
self.pa_writer.write_batch(batch)

def finalize(self, close_stream=True):
if self.pa_writer is not None:
self.write_on_file()
self.pa_writer.close()
self.write_on_file()
self.pa_writer.close()
if close_stream:
self.stream.close()
logger.info(
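On the writer side, ArrowWriter can now be typed directly with Features, and the features/schema/type triple is kept in sync whenever the writer is built. A short sketch of the new constructor path, assuming the in-repo module paths shown in this diff; the data is illustrative:

# Sketch only: an ArrowWriter driven by Features rather than a pa.Schema.
import pyarrow as pa

from nlp.arrow_writer import ArrowWriter
from nlp.features import Features, Value

features = Features({"text": Value("string"), "label": Value("int64")})
buf = pa.BufferOutputStream()
writer = ArrowWriter(features=features, stream=buf, writer_batch_size=2)

writer.write({"text": "a", "label": 0})
writer.write({"text": "b", "label": 1})
writer.finalize()  # flushes any buffered rows, then closes the Arrow stream

table = pa.ipc.open_stream(buf.getvalue()).read_all()
print(table.schema)  # should match pa.schema(features.type), with nullability disabled by default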
4 changes: 4 additions & 0 deletions src/nlp/features.py
@@ -55,6 +55,10 @@ class Value:
_type: str = field(default="Value", init=False, repr=False)

def __post_init__(self):
if self.dtype == "double": # fix inferred type
self.dtype = "float64"
if self.dtype == "float": # fix inferred type
self.dtype = "float32"
self.pa_type = string_to_arrow(self.dtype)

def __call__(self):
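The features.py change normalizes the dtype aliases produced by Arrow type inference, so a declared Value("float64") compares equal to one inferred from an Arrow schema (Arrow stringifies 64-bit floats as "double" and 32-bit floats as "float"). A quick illustration under that assumption:

# Sketch only: the dtype normalization added in Value.__post_init__.
import pyarrow as pa

from nlp.features import Features, Value

assert Value("double").dtype == "float64"  # Arrow's name for 64-bit floats
assert Value("float").dtype == "float32"   # Arrow's name for 32-bit floats

declared = Features({"score": Value("float64")})
inferred = Features.from_arrow_schema(pa.schema([pa.field("score", pa.float64())]))
print(declared == inferred)  # expected True now that the inferred "double" maps to "float64"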