diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 01ee0ef9064..2f2473624fc 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1548,7 +1548,7 @@ def cast_column(self, column: str, feature: FeatureType, new_fingerprint: str) - 'text': Value(dtype='string', id=None)} ``` """ - if hasattr(feature, "cast_storage"): + if hasattr(feature, "decode_example"): dataset = copy.deepcopy(self) dataset.features[column] = feature dataset._fingerprint = new_fingerprint diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 853b117ee76..e9d4df309aa 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -29,11 +29,13 @@ import numpy as np import pandas as pd import pyarrow as pa +import pyarrow.compute as pc import pyarrow.types from pandas.api.extensions import ExtensionArray as PandasExtensionArray from pandas.api.extensions import ExtensionDtype as PandasExtensionDtype from .. import config +from ..table import array_cast from ..utils import logging from ..utils.py_utils import first_non_null_value, zip_dict from .audio import Audio @@ -845,23 +847,29 @@ def str2int(self, values: Union[str, Iterable]): values = [values] return_list = False - output = [] - for value in values: - if self._str2int: - # strip key if not in dict - if value not in self._str2int: - value = str(value).strip() - output.append(self._str2int[str(value)]) - else: - # No names provided, try to integerize - failed_parse = False + output = [self._strval2int(value) for value in values] + return output if return_list else output[0] + + def _strval2int(self, value: str): + failed_parse = False + value = str(value) + # first attempt - raw string value + int_value = self._str2int.get(value) + if int_value is None: + # second attempt - strip whitespace + int_value = self._str2int.get(value.strip()) + if int_value is None: + # third attempt - convert str to int try: - output.append(int(value)) + int_value = int(value) except ValueError: failed_parse = True - if failed_parse or not 0 <= value < self.num_classes: - raise ValueError(f"Invalid string class label {value}") - return output if return_list else output[0] + else: + if int_value < -1 or int_value >= self.num_classes: + failed_parse = True + if failed_parse: + raise ValueError(f"Invalid string class label {value}") + return int_value def int2str(self, values: Union[int, Iterable]): """Conversion integer => class name string.""" @@ -878,11 +886,7 @@ def int2str(self, values: Union[int, Iterable]): if not 0 <= v < self.num_classes: raise ValueError(f"Invalid integer class label {v:d}") - if self._int2str: - output = [self._int2str[int(v)] for v in values] - else: - # No names provided, return str(values) - output = [str(v) for v in values] + output = [self._int2str[int(v)] for v in values] return output if return_list else output[0] def encode_example(self, example_data): @@ -901,6 +905,33 @@ def encode_example(self, example_data): raise ValueError(f"Class label {example_data:d} greater than configured num_classes {self.num_classes}") return example_data + def cast_storage(self, storage: Union[pa.StringArray, pa.IntegerArray]) -> pa.Int64Array: + """Cast an Arrow array to the ClassLabel arrow storage type. + The Arrow types that can be converted to the ClassLabel pyarrow storage type are: + + - pa.string() + - pa.int() + + Args: + storage (Union[pa.StringArray, pa.IntegerArray]): PyArrow array to cast. + + Returns: + pa.Int64Array: Array in the ClassLabel arrow storage type + """ + if isinstance(storage, pa.IntegerArray): + min_max = pc.min_max(storage).as_py() + if min_max["min"] < -1: + raise ValueError(f"Class label {min_max['min']} less than -1") + if min_max["max"] >= self.num_classes: + raise ValueError( + f"Class label {min_max['max']} greater than configured num_classes {self.num_classes}" + ) + elif isinstance(storage, pa.StringArray): + storage = pa.array( + [self._strval2int(label) if label is not None else None for label in storage.to_pylist()] + ) + return array_cast(storage, self.pa_type) + @staticmethod def _load_names_from_file(names_filepath): with open(names_filepath, encoding="utf-8") as f: @@ -1265,6 +1296,24 @@ def require_decoding(feature: FeatureType, ignore_decode_attribute: bool = False return hasattr(feature, "decode_example") and (feature.decode if not ignore_decode_attribute else True) +def require_storage_cast(feature: FeatureType) -> bool: + """Check if a (possibly nested) feature requires storage casting. + + Args: + feature (FeatureType): the feature type to be checked + Returns: + :obj:`bool` + """ + if isinstance(feature, dict): + return any(require_storage_cast(f) for f in feature.values()) + elif isinstance(feature, (list, tuple)): + return require_storage_cast(feature[0]) + elif isinstance(feature, Sequence): + return require_storage_cast(feature.feature) + else: + return hasattr(feature, "cast_storage") + + def keep_features_dicts_synced(func): """ Wrapper to keep the secondary dictionary, which tracks whether keys are decodable, of the :class:`datasets.Features` object diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index bf17d1367a6..fb3a77855dc 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -7,6 +7,8 @@ import datasets import datasets.config +from datasets.features.features import require_storage_cast +from datasets.table import table_cast logger = datasets.utils.logging.get_logger(__name__) @@ -146,19 +148,37 @@ def _split_generators(self, dl_manager): splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)})) return splits + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.config.features is not None: + schema = self.config.features.arrow_schema + if all(not require_storage_cast(feature) for feature in self.config.features.values()): + # cheaper cast + pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema) + else: + # more expensive cast; allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, schema) + return pa_table + def _generate_tables(self, files): - schema = pa.schema(self.config.features.type) if self.config.features is not None else None + schema = self.config.features.arrow_schema if self.config.features else None # dtype allows reading an int column as str - dtype = {name: dtype.to_pandas_dtype() for name, dtype in zip(schema.names, schema.types)} if schema else None + dtype = ( + { + name: dtype.to_pandas_dtype() if not require_storage_cast(feature) else object + for name, dtype, feature in zip(schema.names, schema.types, self.config.features.values()) + } + if schema is not None + else None + ) for file_idx, file in enumerate(files): csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs) try: for batch_idx, df in enumerate(csv_file_reader): - pa_table = pa.Table.from_pandas(df, schema=schema) + pa_table = pa.Table.from_pandas(df) # Uncomment for debugging (will print the Arrow table size and elements) # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) - yield (file_idx, batch_idx), pa_table + yield (file_idx, batch_idx), self._cast_table(pa_table) except ValueError as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index 1e3a1949244..86ef8eb7d5e 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -25,10 +25,6 @@ class JsonConfig(datasets.BuilderConfig): chunksize: int = 10 << 20 # 10MB newlines_in_values: Optional[bool] = None - @property - def schema(self): - return self.features.arrow_schema if self.features is not None else None - class Json(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = JsonConfig @@ -64,21 +60,11 @@ def _split_generators(self, dl_manager): splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)})) return splits - def _cast_classlabels(self, pa_table: pa.Table) -> pa.Table: - if self.config.features: - # Encode column if ClassLabel - for i, col in enumerate(self.config.features.keys()): - if isinstance(self.config.features[col], datasets.ClassLabel): - if pa_table[col].type == pa.string(): - pa_table = pa_table.set_column( - i, self.config.schema.field(col), [self.config.features[col].str2int(pa_table[col])] - ) - elif pa_table[col].type != self.config.schema.field(col).type: - raise ValueError( - f"Field '{col}' from the JSON data of type {pa_table[col].type} is not compatible with ClassLabel. Compatible types are int64 and string." - ) - # Cast allows str <-> int/float or str to Audio for example - pa_table = table_cast(pa_table, self.config.schema) + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.config.features is not None: + # more expensive cast to support nested structures with keys in a different order + # allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, self.config.features.arrow_schema) return pa_table def _generate_tables(self, files): @@ -98,7 +84,7 @@ def _generate_tables(self, files): else: mapping = dataset pa_table = pa.Table.from_pydict(mapping=mapping) - yield file_idx, self._cast_classlabels(pa_table) + yield file_idx, self._cast_table(pa_table) # If the file has one json object per line else: @@ -153,5 +139,5 @@ def _generate_tables(self, files): # Uncomment for debugging (will print the Arrow table size and elements) # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) - yield (file_idx, batch_idx), self._cast_classlabels(pa_table) + yield (file_idx, batch_idx), self._cast_table(pa_table) batch_idx += 1 diff --git a/src/datasets/packaged_modules/pandas/pandas.py b/src/datasets/packaged_modules/pandas/pandas.py index 2e7c090fed3..88e358529b5 100644 --- a/src/datasets/packaged_modules/pandas/pandas.py +++ b/src/datasets/packaged_modules/pandas/pandas.py @@ -1,12 +1,25 @@ +from dataclasses import dataclass +from typing import Optional + import pandas as pd import pyarrow as pa import datasets +from datasets.table import table_cast + + +@dataclass +class PandasConfig(datasets.BuilderConfig): + """BuilderConfig for Pandas.""" + + features: Optional[datasets.Features] = None class Pandas(datasets.ArrowBasedBuilder): + BUILDER_CONFIG_CLASS = PandasConfig + def _info(self): - return datasets.DatasetInfo() + return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): """We handle string, list and dicts in datafiles""" @@ -25,8 +38,15 @@ def _split_generators(self, dl_manager): splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) return splits + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.config.features is not None: + # more expensive cast to support nested features with keys in a different order + # allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, self.config.features.arrow_schema) + return pa_table + def _generate_tables(self, files): for i, file in enumerate(files): with open(file, "rb") as f: pa_table = pa.Table.from_pandas(pd.read_pickle(f)) - yield i, pa_table + yield i, self._cast_table(pa_table) diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 2f819f3a91f..23e6e664126 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -5,6 +5,7 @@ import pyarrow.parquet as pq import datasets +from datasets.table import table_cast logger = datasets.utils.logging.get_logger(__name__) @@ -42,8 +43,15 @@ def _split_generators(self, dl_manager): splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) return splits + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.config.features is not None: + # more expensive cast to support nested features with keys in a different order + # allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, self.config.features.arrow_schema) + return pa_table + def _generate_tables(self, files): - schema = pa.schema(self.config.features.type) if self.config.features is not None else None + schema = self.config.features.arrow_schema if self.config.features is not None else None if self.config.features is not None and self.config.columns is not None: if sorted(field.name for field in schema) != sorted(self.config.columns): raise ValueError( @@ -57,12 +65,10 @@ def _generate_tables(self, files): parquet_file.iter_batches(batch_size=self.config.batch_size, columns=self.config.columns) ): pa_table = pa.Table.from_batches([record_batch]) - if self.config.features is not None: - pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema) # Uncomment for debugging (will print the Arrow table size and elements) # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) - yield f"{file_idx}_{batch_idx}", pa_table + yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table) except ValueError as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise diff --git a/src/datasets/packaged_modules/text/text.py b/src/datasets/packaged_modules/text/text.py index a6df99bbb40..38800cda307 100644 --- a/src/datasets/packaged_modules/text/text.py +++ b/src/datasets/packaged_modules/text/text.py @@ -5,6 +5,8 @@ import pyarrow as pa import datasets +from datasets.features.features import require_storage_cast +from datasets.table import table_cast logger = datasets.utils.logging.get_logger(__name__) @@ -50,8 +52,21 @@ def _split_generators(self, dl_manager): splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": dl_manager.iter_files(files)})) return splits + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.config.features is not None: + schema = self.config.features.arrow_schema + if all(not require_storage_cast(feature) for feature in self.config.features.values()): + # cheaper cast + pa_table = pa_table.cast(schema) + else: + # more expensive cast; allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, schema) + return pa_table + else: + return pa_table.cast(pa.schema({"text": pa.string()})) + def _generate_tables(self, files): - schema = pa.schema(self.config.features.type if self.config.features is not None else {"text": pa.string()}) + pa_table_names = list(self.config.features) if self.config.features is not None else ["text"] for file_idx, file in enumerate(files): # open in text mode, by default translates universal newlines ("\n", "\r\n" and "\r") into "\n" with open(file, encoding=self.config.encoding) as f: @@ -66,11 +81,11 @@ def _generate_tables(self, files): batch = StringIO(batch).readlines() if not self.config.keep_linebreaks: batch = [line.rstrip("\n") for line in batch] - pa_table = pa.Table.from_arrays([pa.array(batch)], schema=schema) + pa_table = pa.Table.from_arrays([pa.array(batch)], names=pa_table_names) # Uncomment for debugging (will print the Arrow table size and elements) # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) - yield (file_idx, batch_idx), pa_table + yield (file_idx, batch_idx), self._cast_table(pa_table) batch_idx += 1 elif self.config.sample_by == "paragraph": batch_idx = 0 @@ -82,15 +97,15 @@ def _generate_tables(self, files): batch += f.readline() # finish current line batch = batch.split("\n\n") pa_table = pa.Table.from_arrays( - [pa.array([example for example in batch[:-1] if example])], schema=schema + [pa.array([example for example in batch[:-1] if example])], names=pa_table_names ) # Uncomment for debugging (will print the Arrow table size and elements) # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) - yield (file_idx, batch_idx), pa_table + yield (file_idx, batch_idx), self._cast_table(pa_table) batch_idx += 1 batch = batch[-1] elif self.config.sample_by == "document": text = f.read() - pa_table = pa.Table.from_arrays([pa.array([text])], schema=schema) - yield file_idx, pa_table + pa_table = pa.Table.from_arrays([pa.array([text])], names=pa_table_names) + yield file_idx, self._cast_table(pa_table) diff --git a/tests/conftest.py b/tests/conftest.py index 5aa997c1c57..51d21be3499 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -470,16 +470,16 @@ def text_path_with_unicode_new_lines(tmp_path_factory): @pytest.fixture(scope="session") -def image_path(): +def image_file(): return os.path.join(os.path.dirname(__file__), "features", "data", "test_image_rgb.jpg") @pytest.fixture(scope="session") -def zip_image_path(image_path, tmp_path_factory): +def zip_image_path(image_file, tmp_path_factory): import zipfile path = tmp_path_factory.mktemp("data") / "dataset.img.zip" with zipfile.ZipFile(path, "w") as f: - f.write(image_path, arcname=os.path.basename(image_path)) - f.write(image_path, arcname=os.path.basename(image_path).replace(".jpg", "2.jpg")) + f.write(image_file, arcname=os.path.basename(image_file)) + f.write(image_file, arcname=os.path.basename(image_file).replace(".jpg", "2.jpg")) return path diff --git a/tests/features/test_features.py b/tests/features/test_features.py index 54b4a6cb578..a800ba3419f 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -284,7 +284,7 @@ def test_classlabel_str2int(): classlabel = ClassLabel(names=names) for label in names: assert classlabel.str2int(label) == names.index(label) - with pytest.raises(KeyError): + with pytest.raises(ValueError): classlabel.str2int("__bad_label_name__") with pytest.raises(ValueError): classlabel.str2int(1) diff --git a/tests/packaged_modules/test_csv.py b/tests/packaged_modules/test_csv.py index f9473b6a799..e9e8e6f9f31 100644 --- a/tests/packaged_modules/test_csv.py +++ b/tests/packaged_modules/test_csv.py @@ -1,10 +1,14 @@ import os import textwrap +import pyarrow as pa import pytest +from datasets import ClassLabel, Features, Image from datasets.packaged_modules.csv.csv import Csv +from ..utils import require_pil + @pytest.fixture def csv_file(tmp_path): @@ -36,6 +40,36 @@ def malformed_csv_file(tmp_path): return str(filename) +@pytest.fixture +def csv_file_with_image(tmp_path, image_file): + filename = tmp_path / "csv_with_image.csv" + data = textwrap.dedent( + f"""\ + image + {image_file} + """ + ) + with open(filename, "w") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def csv_file_with_label(tmp_path): + filename = tmp_path / "csv_with_label.csv" + data = textwrap.dedent( + """\ + label + good + bad + good + """ + ) + with open(filename, "w") as f: + f.write(data) + return str(filename) + + def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed_csv_file, caplog): csv = Csv() generator = csv._generate_tables([csv_file, malformed_csv_file]) @@ -48,3 +82,27 @@ def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed and os.path.basename(malformed_csv_file) in record.message for record in caplog.records ) + + +@require_pil +def test_csv_cast_image(csv_file_with_image): + with open(csv_file_with_image, encoding="utf-8") as f: + image_file = f.read().splitlines()[1] + csv = Csv(encoding="utf-8", features=Features({"image": Image()})) + generator = csv._generate_tables([csv_file_with_image]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa_table.schema.field("image").type == Image()() + generated_content = pa_table.to_pydict()["image"] + assert generated_content == [{"path": image_file, "bytes": None}] + + +@require_pil +def test_csv_cast_label(csv_file_with_label): + with open(csv_file_with_label, encoding="utf-8") as f: + labels = f.read().splitlines()[1:] + csv = Csv(encoding="utf-8", features=Features({"label": ClassLabel(names=["good", "bad"])})) + generator = csv._generate_tables([csv_file_with_label]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa_table.schema.field("label").type == ClassLabel(names=["good", "bad"])() + generated_content = pa_table.to_pydict()["label"] + assert generated_content == [ClassLabel(names=["good", "bad"]).str2int(label) for label in labels] diff --git a/tests/packaged_modules/test_imagefolder.py b/tests/packaged_modules/test_imagefolder.py index e6663c4361c..de725ad3623 100644 --- a/tests/packaged_modules/test_imagefolder.py +++ b/tests/packaged_modules/test_imagefolder.py @@ -1,4 +1,3 @@ -import os import shutil import textwrap @@ -18,11 +17,6 @@ def cache_dir(tmp_path): return str(tmp_path / "imagefolder_cache_dir") -@pytest.fixture -def image_file(): - return os.path.join(os.path.dirname(__file__), "..", "features", "data", "test_image_rgb.jpg") - - @pytest.fixture def image_file_with_metadata(tmp_path, image_file): image_filename = tmp_path / "image_rgb.jpg" diff --git a/tests/packaged_modules/test_text.py b/tests/packaged_modules/test_text.py index bc554feb72b..38830b20e4d 100644 --- a/tests/packaged_modules/test_text.py +++ b/tests/packaged_modules/test_text.py @@ -3,8 +3,11 @@ import pyarrow as pa import pytest +from datasets import Features, Image from datasets.packaged_modules.text.text import Text +from ..utils import require_pil + @pytest.fixture def text_file(tmp_path): @@ -22,6 +25,14 @@ def text_file(tmp_path): return str(filename) +@pytest.fixture +def text_file_with_image(tmp_path, image_file): + filename = tmp_path / "text_with_image.txt" + with open(filename, "w", encoding="utf-8") as f: + f.write(image_file) + return str(filename) + + @pytest.mark.parametrize("keep_linebreaks", [True, False]) def test_text_linebreaks(text_file, keep_linebreaks): with open(text_file, encoding="utf-8") as f: @@ -30,3 +41,15 @@ def test_text_linebreaks(text_file, keep_linebreaks): generator = text._generate_tables([text_file]) generated_content = pa.concat_tables([table for _, table in generator]).to_pydict()["text"] assert generated_content == expected_content + + +@require_pil +def test_text_cast_image(text_file_with_image): + with open(text_file_with_image, encoding="utf-8") as f: + image_file = f.read().splitlines()[0] + text = Text(encoding="utf-8", features=Features({"image": Image()})) + generator = text._generate_tables([text_file_with_image]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa_table.schema.field("image").type == Image()() + generated_content = pa_table.to_pydict()["image"] + assert generated_content == [{"path": image_file, "bytes": None}] diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 96fd19a83c1..00f1110009b 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2691,10 +2691,10 @@ def test_dataset_add_column(column, expected_dtype, in_memory, transform, datase @pytest.mark.parametrize( "item", [ - {"col_1": "4", "col_2": 4, "col_3": 4.0}, - {"col_1": "4", "col_2": "4", "col_3": "4"}, - {"col_1": 4, "col_2": 4, "col_3": 4}, - {"col_1": 4.0, "col_2": 4.0, "col_3": 4.0}, + {"col_1": "2", "col_2": 2, "col_3": 2.0}, + {"col_1": "2", "col_2": "2", "col_3": "2"}, + {"col_1": 2, "col_2": 2, "col_3": 2}, + {"col_1": 2.0, "col_2": 2.0, "col_3": 2.0}, ], ) def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path, transform):