Skip to content
Merged
2 changes: 1 addition & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ def cast_column(self, column: str, feature: FeatureType, new_fingerprint: str) -
'text': Value(dtype='string', id=None)}
```
"""
if hasattr(feature, "cast_storage"):
if hasattr(feature, "decode_example"):
dataset = copy.deepcopy(self)
dataset.features[column] = feature
dataset._fingerprint = new_fingerprint
Expand Down
87 changes: 68 additions & 19 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.types
from pandas.api.extensions import ExtensionArray as PandasExtensionArray
from pandas.api.extensions import ExtensionDtype as PandasExtensionDtype

from .. import config
from ..table import array_cast
from ..utils import logging
from ..utils.py_utils import first_non_null_value, zip_dict
from .audio import Audio
Expand Down Expand Up @@ -845,23 +847,29 @@ def str2int(self, values: Union[str, Iterable]):
values = [values]
return_list = False

output = []
for value in values:
if self._str2int:
# strip key if not in dict
if value not in self._str2int:
value = str(value).strip()
output.append(self._str2int[str(value)])
else:
# No names provided, try to integerize
failed_parse = False
output = [self._strval2int(value) for value in values]
return output if return_list else output[0]

def _strval2int(self, value: str):
failed_parse = False
value = str(value)
# first attempt - raw string value
int_value = self._str2int.get(value)
if int_value is None:
# second attempt - strip whitespace
int_value = self._str2int.get(value.strip())
if int_value is None:
# third attempt - convert str to int
try:
output.append(int(value))
int_value = int(value)
except ValueError:
failed_parse = True
if failed_parse or not 0 <= value < self.num_classes:
raise ValueError(f"Invalid string class label {value}")
return output if return_list else output[0]
else:
if int_value < -1 or int_value >= self.num_classes:
failed_parse = True
if failed_parse:
raise ValueError(f"Invalid string class label {value}")
return int_value

def int2str(self, values: Union[int, Iterable]):
"""Conversion integer => class name string."""
Expand All @@ -878,11 +886,7 @@ def int2str(self, values: Union[int, Iterable]):
if not 0 <= v < self.num_classes:
raise ValueError(f"Invalid integer class label {v:d}")

if self._int2str:
output = [self._int2str[int(v)] for v in values]
else:
# No names provided, return str(values)
output = [str(v) for v in values]
output = [self._int2str[int(v)] for v in values]
return output if return_list else output[0]

def encode_example(self, example_data):
Expand All @@ -901,6 +905,33 @@ def encode_example(self, example_data):
raise ValueError(f"Class label {example_data:d} greater than configured num_classes {self.num_classes}")
return example_data

def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.Int64Array:
    """Cast an Arrow array to the ClassLabel arrow storage type.

    The Arrow types that can be converted to the ClassLabel pyarrow storage type are:

    - pa.string()
    - any integer type (e.g. pa.int32(), pa.int64())

    Integer arrays are range-checked before the cast; string arrays are first converted
    label by label with ``self._strval2int`` (null entries stay null).

    Args:
        storage (Union[pa.StringArray, pa.IntegerArray]): PyArrow array to cast.

    Returns:
        pa.Int64Array: Array in the ClassLabel arrow storage type

    Raises:
        ValueError: If an integer array contains a value below -1 or a value greater
            than or equal to ``self.num_classes``.
    """
    if isinstance(storage, pa.IntegerArray):
        min_max = pc.min_max(storage).as_py()
        # For an empty or all-null array both bounds come back as None; comparing None
        # with an int would raise TypeError, so only validate the bounds that exist.
        lowest, highest = min_max["min"], min_max["max"]
        if lowest is not None and lowest < -1:
            raise ValueError(f"Class label {lowest} less than -1")
        if highest is not None and highest >= self.num_classes:
            raise ValueError(
                f"Class label {highest} greater than configured num_classes {self.num_classes}"
            )
    elif isinstance(storage, pa.StringArray):
        storage = pa.array(
            [self._strval2int(label) if label is not None else None for label in storage.to_pylist()]
        )
    return array_cast(storage, self.pa_type)

@staticmethod
def _load_names_from_file(names_filepath):
with open(names_filepath, encoding="utf-8") as f:
Expand Down Expand Up @@ -1265,6 +1296,24 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False
return hasattr(feature, "decode_example") and (feature.decode if not ignore_decode_attribute else True)


def require_storage_cast(feature: FeatureType) -> bool:
    """Tell whether a (possibly nested) feature needs its storage to be cast.

    Containers are searched recursively: every value of a dict, the first item of a
    list/tuple, and the inner ``feature`` of a ``Sequence``. Any other object needs a
    storage cast exactly when it exposes a ``cast_storage`` attribute.

    Args:
        feature (FeatureType): the feature type to be checked
    Returns:
        :obj:`bool`
    """
    if isinstance(feature, dict):
        for sub_feature in feature.values():
            if require_storage_cast(sub_feature):
                return True
        return False
    if isinstance(feature, (list, tuple)):
        # only the first element is inspected
        return require_storage_cast(feature[0])
    if isinstance(feature, Sequence):
        return require_storage_cast(feature.feature)
    return hasattr(feature, "cast_storage")


def keep_features_dicts_synced(func):
"""
Wrapper to keep the secondary dictionary, which tracks whether keys are decodable, of the :class:`datasets.Features` object
Expand Down
28 changes: 24 additions & 4 deletions src/datasets/packaged_modules/csv/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import datasets
import datasets.config
from datasets.features.features import require_storage_cast
from datasets.table import table_cast


logger = datasets.utils.logging.get_logger(__name__)
Expand Down Expand Up @@ -146,19 +148,37 @@ def _split_generators(self, dl_manager):
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
    """Bring ``pa_table`` in line with the configured features.

    The table is returned unchanged when the config carries no features.
    """
    if self.config.features is None:
        return pa_table
    schema = self.config.features.arrow_schema
    needs_storage_cast = any(require_storage_cast(feature) for feature in self.config.features.values())
    if needs_storage_cast:
        # more expensive cast; allows str <-> int/float or str to Audio for example
        return table_cast(pa_table, schema)
    # cheaper cast: take the columns in schema order and rebuild the table with that schema
    columns = [pa_table[field.name] for field in schema]
    return pa.Table.from_arrays(columns, schema=schema)

def _generate_tables(self, files):
schema = pa.schema(self.config.features.type) if self.config.features is not None else None
schema = self.config.features.arrow_schema if self.config.features else None
# dtype allows reading an int column as str
dtype = {name: dtype.to_pandas_dtype() for name, dtype in zip(schema.names, schema.types)} if schema else None
dtype = (
{
name: dtype.to_pandas_dtype() if not require_storage_cast(feature) else object
for name, dtype, feature in zip(schema.names, schema.types, self.config.features.values())
}
if schema is not None
else None
)
for file_idx, file in enumerate(files):
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)
try:
for batch_idx, df in enumerate(csv_file_reader):
pa_table = pa.Table.from_pandas(df, schema=schema)
pa_table = pa.Table.from_pandas(df)
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield (file_idx, batch_idx), pa_table
yield (file_idx, batch_idx), self._cast_table(pa_table)
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise
28 changes: 7 additions & 21 deletions src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ class JsonConfig(datasets.BuilderConfig):
chunksize: int = 10 << 20 # 10MB
newlines_in_values: Optional[bool] = None

@property
def schema(self):
return self.features.arrow_schema if self.features is not None else None


class Json(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = JsonConfig
Expand Down Expand Up @@ -64,21 +60,11 @@ def _split_generators(self, dl_manager):
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
return splits

def _cast_classlabels(self, pa_table: pa.Table) -> pa.Table:
if self.config.features:
# Encode column if ClassLabel
for i, col in enumerate(self.config.features.keys()):
if isinstance(self.config.features[col], datasets.ClassLabel):
if pa_table[col].type == pa.string():
pa_table = pa_table.set_column(
i, self.config.schema.field(col), [self.config.features[col].str2int(pa_table[col])]
)
elif pa_table[col].type != self.config.schema.field(col).type:
raise ValueError(
f"Field '{col}' from the JSON data of type {pa_table[col].type} is not compatible with ClassLabel. Compatible types are int64 and string."
)
# Cast allows str <-> int/float or str to Audio for example
pa_table = table_cast(pa_table, self.config.schema)
def _cast_table(self, pa_table: pa.Table) -> pa.Table:
    """Cast ``pa_table`` to the Arrow schema of the configured features.

    The table is returned unchanged when the config carries no features.
    """
    if self.config.features is None:
        return pa_table
    # more expensive cast to support nested structures with keys in a different order
    # allows str <-> int/float or str to Audio for example
    target_schema = self.config.features.arrow_schema
    return table_cast(pa_table, target_schema)

def _generate_tables(self, files):
Expand All @@ -98,7 +84,7 @@ def _generate_tables(self, files):
else:
mapping = dataset
pa_table = pa.Table.from_pydict(mapping=mapping)
yield file_idx, self._cast_classlabels(pa_table)
yield file_idx, self._cast_table(pa_table)

# If the file has one json object per line
else:
Expand Down Expand Up @@ -153,5 +139,5 @@ def _generate_tables(self, files):
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield (file_idx, batch_idx), self._cast_classlabels(pa_table)
yield (file_idx, batch_idx), self._cast_table(pa_table)
batch_idx += 1
24 changes: 22 additions & 2 deletions src/datasets/packaged_modules/pandas/pandas.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
from dataclasses import dataclass
from typing import Optional

import pandas as pd
import pyarrow as pa

import datasets
from datasets.table import table_cast


@dataclass
class PandasConfig(datasets.BuilderConfig):
    """Builder configuration for the Pandas packaged module.

    The only option added on top of ``datasets.BuilderConfig`` is ``features``.
    """

    # Optional user-supplied features; None (the default) means no features were requested.
    features: Optional[datasets.Features] = None

class Pandas(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = PandasConfig

def _info(self):
return datasets.DatasetInfo()
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
"""We handle string, list and dicts in datafiles"""
Expand All @@ -25,8 +38,15 @@ def _split_generators(self, dl_manager):
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
    """Cast ``pa_table`` to the Arrow schema of the configured features.

    The table is returned unchanged when the config carries no features.
    """
    requested_features = self.config.features
    if requested_features is None:
        return pa_table
    # more expensive cast to support nested features with keys in a different order
    # allows str <-> int/float or str to Audio for example
    return table_cast(pa_table, requested_features.arrow_schema)

def _generate_tables(self, files):
for i, file in enumerate(files):
with open(file, "rb") as f:
pa_table = pa.Table.from_pandas(pd.read_pickle(f))
yield i, pa_table
yield i, self._cast_table(pa_table)
14 changes: 10 additions & 4 deletions src/datasets/packaged_modules/parquet/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pyarrow.parquet as pq

import datasets
from datasets.table import table_cast


logger = datasets.utils.logging.get_logger(__name__)
Expand Down Expand Up @@ -42,8 +43,15 @@ def _split_generators(self, dl_manager):
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
    """Cast ``pa_table`` to the Arrow schema of the configured features.

    The table is returned unchanged when the config carries no features.
    """
    if self.config.features is None:
        return pa_table
    # more expensive cast to support nested features with keys in a different order
    # allows str <-> int/float or str to Audio for example
    casted_table = table_cast(pa_table, self.config.features.arrow_schema)
    return casted_table

def _generate_tables(self, files):
schema = pa.schema(self.config.features.type) if self.config.features is not None else None
schema = self.config.features.arrow_schema if self.config.features is not None else None
if self.config.features is not None and self.config.columns is not None:
if sorted(field.name for field in schema) != sorted(self.config.columns):
raise ValueError(
Expand All @@ -57,12 +65,10 @@ def _generate_tables(self, files):
parquet_file.iter_batches(batch_size=self.config.batch_size, columns=self.config.columns)
):
pa_table = pa.Table.from_batches([record_batch])
if self.config.features is not None:
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield f"{file_idx}_{batch_idx}", pa_table
yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table)
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise
29 changes: 22 additions & 7 deletions src/datasets/packaged_modules/text/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pyarrow as pa

import datasets
from datasets.features.features import require_storage_cast
from datasets.table import table_cast


logger = datasets.utils.logging.get_logger(__name__)
Expand Down Expand Up @@ -50,8 +52,21 @@ def _split_generators(self, dl_manager):
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
    """Cast ``pa_table`` to the configured features, or to a single string ``text`` column.

    When the config carries no features, the table is cast to the schema
    ``{"text": pa.string()}``.
    """
    if self.config.features is None:
        default_schema = pa.schema({"text": pa.string()})
        return pa_table.cast(default_schema)
    schema = self.config.features.arrow_schema
    if any(require_storage_cast(feature) for feature in self.config.features.values()):
        # more expensive cast; allows str <-> int/float or str to Audio for example
        return table_cast(pa_table, schema)
    # cheaper cast
    return pa_table.cast(schema)

def _generate_tables(self, files):
schema = pa.schema(self.config.features.type if self.config.features is not None else {"text": pa.string()})
pa_table_names = list(self.config.features) if self.config.features is not None else ["text"]
for file_idx, file in enumerate(files):
# open in text mode, by default translates universal newlines ("\n", "\r\n" and "\r") into "\n"
with open(file, encoding=self.config.encoding) as f:
Expand All @@ -66,11 +81,11 @@ def _generate_tables(self, files):
batch = StringIO(batch).readlines()
if not self.config.keep_linebreaks:
batch = [line.rstrip("\n") for line in batch]
pa_table = pa.Table.from_arrays([pa.array(batch)], schema=schema)
pa_table = pa.Table.from_arrays([pa.array(batch)], names=pa_table_names)
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield (file_idx, batch_idx), pa_table
yield (file_idx, batch_idx), self._cast_table(pa_table)
batch_idx += 1
elif self.config.sample_by == "paragraph":
batch_idx = 0
Expand All @@ -82,15 +97,15 @@ def _generate_tables(self, files):
batch += f.readline() # finish current line
batch = batch.split("\n\n")
pa_table = pa.Table.from_arrays(
[pa.array([example for example in batch[:-1] if example])], schema=schema
[pa.array([example for example in batch[:-1] if example])], names=pa_table_names
)
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield (file_idx, batch_idx), pa_table
yield (file_idx, batch_idx), self._cast_table(pa_table)
batch_idx += 1
batch = batch[-1]
elif self.config.sample_by == "document":
text = f.read()
pa_table = pa.Table.from_arrays([pa.array([text])], schema=schema)
yield file_idx, pa_table
pa_table = pa.Table.from_arrays([pa.array([text])], names=pa_table_names)
yield file_idx, self._cast_table(pa_table)
Loading