diff --git a/src/nlp/arrow_dataset.py b/src/nlp/arrow_dataset.py index 9d4f5bfd0c6..f3d80932821 100644 --- a/src/nlp/arrow_dataset.py +++ b/src/nlp/arrow_dataset.py @@ -38,7 +38,7 @@ from .info import DatasetInfo from .search import IndexableMixin from .splits import NamedSplit -from .utils import map_all_sequences_to_lists, map_nested +from .utils import map_nested logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ class DatasetInfoMixin(object): at the base level of the Dataset for easy access. """ - def __init__(self, info: Optional[DatasetInfo], split: Optional[NamedSplit]): + def __init__(self, info: DatasetInfo, split: Optional[NamedSplit]): self._info = info self._split = split @@ -92,7 +92,7 @@ def download_size(self) -> Optional[int]: return self._info.download_size @property - def features(self): + def features(self) -> Features: return self._info.features @property @@ -131,6 +131,7 @@ def __init__( info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, ): + info = info.copy() if info is not None else DatasetInfo() DatasetInfoMixin.__init__(self, info=info, split=split) IndexableMixin.__init__(self) self._data: pa.Table = arrow_table @@ -139,6 +140,15 @@ def __init__( self._format_kwargs: dict = {} self._format_columns: Optional[list] = None self._output_all_columns: bool = False + inferred_features = Features.from_arrow_schema(arrow_table.schema) + if self.info.features is not None: + if self.info.features.type != inferred_features.type: + self.info.features = inferred_features + else: + pass # keep the original features + else: + self.info.features = inferred_features + assert self.features is not None, "Features can't be None in a Dataset object" @classmethod def from_file( @@ -177,7 +187,7 @@ def from_pandas( Be aware that Series of the `object` dtype don't carry enough information to always lead to a meaningful Arrow type. In the case that we cannot infer a type, e.g. because the DataFrame is of length 0 or the Series only contains None/nan objects, the type is set to - null. This behavior can be avoided by constructing an explicit schema and passing it to this function. + null. This behavior can be avoided by constructing explicit features and passing them to this function. Args: df (:obj:``pandas.DataFrame``): the dataframe that contains the dataset. @@ -186,21 +196,16 @@ def from_pandas( description, citation, etc. split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split. """ - if info is None: - info = DatasetInfo() - if info.features is None: - info.features = features - elif info.features != features and features is not None: + if info is not None and features is not None and info.features != features: raise ValueError( "Features specified in `features` and `info.features` can't be different:\n{}\n{}".format( features, info.features ) ) + features = features if features is not None else info.features if info is not None else None pa_table: pa.Table = pa.Table.from_pandas( - df=df, schema=pa.schema(info.features.type) if info.features is not None else None + df=df, schema=pa.schema(features.type) if features is not None else None ) - if info.features is None: - info.features = Features.from_arrow_schema(pa_table.schema) return cls(pa_table, info=info, split=split) @classmethod @@ -221,21 +226,16 @@ def from_dict( description, citation, etc. split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split.
""" - if info is None: - info = DatasetInfo() - if info.features is None: - info.features = features - elif info.features != features and features is not None: + if info is not None and features is not None and info.features != features: raise ValueError( "Features specified in `features` and `info.features` can't be different:\n{}\n{}".format( features, info.features ) ) + features = features if features is not None else info.feature if info is not None else None pa_table: pa.Table = pa.Table.from_pydict( - mapping=mapping, schema=pa.schema(info.features.type) if info.features is not None else None + mapping=mapping, schema=pa.schema(features.type) if features is not None else None ) - if info.features is None: - info.features = Features.from_arrow_schema(pa_table.schema) return cls(pa_table, info=info, split=split) @property @@ -277,14 +277,6 @@ def column_names(self) -> List[str]: """Names of the columns in the dataset. """ return self._data.column_names - @property - def schema(self) -> pa.Schema: - """The Arrow schema of the Apache Arrow table backing the dataset. - You probably don't need to access directly this and can rather use - :func:`nlp.Dataset.features` to inspect the dataset features. - """ - return self._data.schema - @property def shape(self): """Shape of the dataset (number of columns, number of rows).""" @@ -340,6 +332,7 @@ def dictionary_encode_column(self, column: str): casted_field = pa.field(field.name, pa.dictionary(pa.int32(), field.type), nullable=False) casted_schema.set(field_index, casted_field) self._data = self._data.cast(casted_schema) + self.info.features = Features.from_arrow_schema(self._data.schema) def flatten(self, max_depth=16): """ Flatten the Table. @@ -352,7 +345,7 @@ def flatten(self, max_depth=16): else: break if self.info is not None: - self.info.features = Features.from_arrow_schema(self.schema) + self.info.features = Features.from_arrow_schema(self._data.schema) logger.info( "Flattened dataset from depth {} to depth {}.".format(depth, 1 if depth + 1 < max_depth else "unknown") ) @@ -380,8 +373,7 @@ def __iter__(self): ) def __repr__(self): - schema_str = dict((a, str(b)) for a, b in zip(self._data.schema.names, self._data.schema.types)) - return f"Dataset(schema: {schema_str}, num_rows: {self.num_rows})" + return f"Dataset(features: {self.features}, num_rows: {self.num_rows})" @property def format(self): @@ -685,7 +677,7 @@ def map( load_from_cache_file: bool = True, cache_file_name: Optional[str] = None, writer_batch_size: Optional[int] = 1000, - arrow_schema: Optional[pa.Schema] = None, + features: Optional[Features] = None, disable_nullable: bool = True, verbose: bool = True, ): @@ -712,7 +704,7 @@ def map( results of the computation instead of the automatically generated cache file name. `writer_batch_size` (`int`, default: `1000`): Number of rows per write operation for the cache file writer. Higher value gives smaller cache files, lower value consume less temporary memory while running `.map()`. - `arrow_schema` (`Optional[pa.Schema]`, default: `None`): Use a specific Apache Arrow Schema to store the cache file + `features` (`Optional[nlp.Features]`, default: `None`): Use a specific Features to store the cache file instead of the automatically generated one. `disable_nullable` (`bool`, default: `True`): Allow null values in the table. `verbose` (`bool`, default: `True`): Set to `False` to deactivate the tqdm progress bar and informations. 
@@ -792,18 +784,6 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F inputs.update(processed_inputs) return inputs - # Find the output schema if none is given - test_inputs = self[:2] if batched else self[0] - test_indices = [0, 1] if batched else 0 - test_output = apply_function_on_filtered_inputs(test_inputs, test_indices) - if arrow_schema is None and update_data: - if not batched: - test_output = self._nest(test_output) - test_output = map_all_sequences_to_lists(test_output) - arrow_schema = pa.Table.from_pydict(test_output).schema - if disable_nullable: - arrow_schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in arrow_schema) - # Check if we've already cached this computation (indexed by a hash) if self._data_files and update_data: if cache_file_name is None: @@ -817,7 +797,7 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F "load_from_cache_file": load_from_cache_file, "cache_file_name": cache_file_name, "writer_batch_size": writer_batch_size, - "arrow_schema": arrow_schema, + "features": features, "disable_nullable": disable_nullable, } cache_file_name = self._get_cache_file_path(function, cache_kwargs) @@ -830,12 +810,12 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F if update_data: if keep_in_memory or not self._data_files: buf_writer = pa.BufferOutputStream() - writer = ArrowWriter(schema=arrow_schema, stream=buf_writer, writer_batch_size=writer_batch_size) + writer = ArrowWriter(features=features, stream=buf_writer, writer_batch_size=writer_batch_size) else: buf_writer = None if verbose: logger.info("Caching processed dataset at %s", cache_file_name) - writer = ArrowWriter(schema=arrow_schema, path=cache_file_name, writer_batch_size=writer_batch_size) + writer = ArrowWriter(features=features, path=cache_file_name, writer_batch_size=writer_batch_size) # Loop over single examples or batches and write to buffer/file if examples are to be updated if not batched: @@ -928,15 +908,8 @@ def map_function(batch, *args): return result - # to avoid errors with the arrow_schema we define it here - test_inputs = self[:2] - if "remove_columns" in kwargs: - test_inputs = {key: test_inputs[key] for key in (test_inputs.keys() - kwargs["remove_columns"])} - test_inputs = map_all_sequences_to_lists(test_inputs) - arrow_schema = pa.Table.from_pydict(test_inputs).schema - # return map function - return self.map(map_function, batched=True, with_indices=with_indices, arrow_schema=arrow_schema, **kwargs) + return self.map(map_function, batched=True, with_indices=with_indices, features=self.features, **kwargs) def select( self, @@ -991,12 +964,12 @@ def select( # Prepare output buffer and batched writer in memory or on file if we update the table if keep_in_memory or not self._data_files: buf_writer = pa.BufferOutputStream() - writer = ArrowWriter(schema=self.schema, stream=buf_writer, writer_batch_size=writer_batch_size) + writer = ArrowWriter(features=self.features, stream=buf_writer, writer_batch_size=writer_batch_size) else: buf_writer = None if verbose: logger.info("Caching processed dataset at %s", cache_file_name) - writer = ArrowWriter(schema=self.schema, path=cache_file_name, writer_batch_size=writer_batch_size) + writer = ArrowWriter(features=self.features, path=cache_file_name, writer_batch_size=writer_batch_size) # Loop over batches and write to buffer/file if examples are to be updated for i in tqdm(range(0, len(indices), reader_batch_size), disable=not 
verbose): diff --git a/src/nlp/arrow_writer.py b/src/nlp/arrow_writer.py index 7834fba61fc..e2c6a3dedde 100644 --- a/src/nlp/arrow_writer.py +++ b/src/nlp/arrow_writer.py @@ -23,6 +23,7 @@ import pyarrow as pa +from .features import Features from .utils.file_utils import HF_DATASETS_CACHE, hash_url_to_filename from .utils.py_utils import map_all_sequences_to_lists @@ -42,6 +43,7 @@ def __init__( self, data_type: Optional[pa.DataType] = None, schema: Optional[pa.Schema] = None, + features: Optional[Features] = None, path: Optional[str] = None, stream: Optional[pa.NativeFile] = None, writer_batch_size: Optional[int] = None, @@ -49,20 +51,27 @@ def __init__( ): if path is None and stream is None: raise ValueError("At least one of path and stream must be provided.") - - if data_type is not None: + if features is not None: + self._features = features + self._schema = pa.schema(features.type) if features is not None else None + self._type: pa.DataType = pa.struct(field for field in self._schema) + elif data_type is not None: self._type: pa.DataType = data_type self._schema: pa.Schema = pa.schema(field for field in self._type) + self._features = Features.from_arrow_schema(self._schema) elif schema is not None: self._schema: pa.Schema = schema self._type: pa.DataType = pa.struct(field for field in self._schema) + self._features = Features.from_arrow_schema(self._schema) else: + self._features = None self._schema = None self._type = None if disable_nullable and self._schema is not None: self._schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in self._type) self._type = pa.struct(pa.field(field.name, field.type, nullable=False) for field in self._type) + self._features = Features.from_arrow_schema(self._schema) self._path = path if stream is None: @@ -76,19 +85,15 @@ def __init__( self._num_bytes = 0 self.current_rows = [] - self._build_writer(schema=self._schema) + self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None + if self._schema is not None: + self._build_writer(schema=self._schema) - def _build_writer(self, pa_table=None, schema=None): - if schema is not None: - self._schema: pa.Schema = schema - self._type: pa.DataType = pa.struct(field for field in self._schema) - self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema) - elif pa_table is not None: - self._schema: pa.Schema = pa_table.schema - self._type: pa.DataType = pa.struct(field for field in self._schema) - self.pa_writer = pa.RecordBatchStreamWriter(self.stream, self._schema) - else: - self.pa_writer = None + def _build_writer(self, schema: pa.Schema): + self._schema: pa.Schema = schema + self._type: pa.DataType = pa.struct(field for field in self._schema) + self._features = Features.from_arrow_schema(self._schema) + self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema) @property def schema(self): @@ -98,6 +103,9 @@ def _write_array_on_file(self, pa_array): """Write a PyArrow Array""" pa_batch = pa.RecordBatch.from_struct_array(pa_array) self._num_bytes += pa_array.nbytes + if self.pa_writer is None: + pa_table = pa.Table.from_batches([pa_batch]) + self._build_writer(schema=pa_table.schema) self.pa_writer.write_batch(pa_batch) def write_on_file(self): @@ -141,8 +149,6 @@ def write(self, example: Dict[str, Any], writer_batch_size: Optional[int] = None self._num_examples += 1 if writer_batch_size is None: writer_batch_size = self.writer_batch_size - if self.pa_writer is None: - self._build_writer(pa_table=pa.Table.from_pydict(example)) if writer_batch_size is not None and 
len(self.current_rows) >= writer_batch_size: self.write_on_file() @@ -156,7 +162,7 @@ def write_batch( """ batch_examples = map_all_sequences_to_lists(batch_examples) if self.pa_writer is None: - self._build_writer(pa_table=pa.Table.from_pydict(batch_examples)) + self._build_writer(schema=pa.Table.from_pydict(batch_examples).schema) pa_table: pa.Table = pa.Table.from_pydict(batch_examples, schema=self._schema) if writer_batch_size is None: writer_batch_size = self.writer_batch_size @@ -175,7 +181,7 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non if writer_batch_size is None: writer_batch_size = self.writer_batch_size if self.pa_writer is None: - self._build_writer(pa_table=pa_table) + self._build_writer(schema=pa_table.schema) batches: List[pa.RecordBatch] = pa_table.to_batches(max_chunksize=writer_batch_size) self._num_bytes += sum(batch.nbytes for batch in batches) self._num_examples += pa_table.num_rows @@ -183,9 +189,8 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non self.pa_writer.write_batch(batch) def finalize(self, close_stream=True): - if self.pa_writer is not None: - self.write_on_file() - self.pa_writer.close() + self.write_on_file() + self.pa_writer.close() if close_stream: self.stream.close() logger.info( diff --git a/src/nlp/features.py b/src/nlp/features.py index 7f49c252648..414e2578a0f 100644 --- a/src/nlp/features.py +++ b/src/nlp/features.py @@ -55,6 +55,10 @@ class Value: _type: str = field(default="Value", init=False, repr=False) def __post_init__(self): + if self.dtype == "double": # fix inferred type + self.dtype = "float64" + if self.dtype == "float": # fix inferred type + self.dtype = "float32" self.pa_type = string_to_arrow(self.dtype) def __call__(self): diff --git a/src/nlp/info.py b/src/nlp/info.py index 0128253b617..0b5b564c884 100644 --- a/src/nlp/info.py +++ b/src/nlp/info.py @@ -29,6 +29,7 @@ - etc. 
""" +import copy import json import logging import os @@ -85,7 +86,7 @@ class DatasetInfo: citation: str = field(default_factory=str) homepage: str = field(default_factory=str) license: str = field(default_factory=str) - features: Features = None + features: Optional[Features] = None supervised_keys: Optional[SupervisedKeysData] = None # Set later by the builder @@ -161,9 +162,16 @@ def from_directory(cls, dataset_info_dir): def update(self, other_dataset_info, ignore_none=True): self_dict = self.__dict__ self_dict.update( - **{k: v for k, v in other_dataset_info.__dict__.items() if (v is not None or not ignore_none)} + **{ + k: copy.deepcopy(v) + for k, v in other_dataset_info.__dict__.items() + if (v is not None or not ignore_none) + } ) + def copy(self) -> "DatasetInfo": + return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()}) + class DatasetInfosDict(dict): def write_to_directory(self, dataset_infos_dir, overwrite=False): diff --git a/src/nlp/splits.py b/src/nlp/splits.py index 244cef5cb78..f86d96fda24 100644 --- a/src/nlp/splits.py +++ b/src/nlp/splits.py @@ -477,7 +477,6 @@ class SplitDict(dict): def __init__(self, *args, dataset_name=None, **kwargs): super(SplitDict, self).__init__(*args, **kwargs) - # super(SplitDict, self).__init__(error_msg="Split {key} already present", **kwargs) self.dataset_name = dataset_name def __getitem__(self, key: Union[SplitBase, str]): @@ -490,14 +489,16 @@ def __getitem__(self, key: Union[SplitBase, str]): return SubSplitInfo(instructions) def __setitem__(self, key: Union[SplitBase, str], value: SplitInfo): - raise ValueError("Cannot add elem. Use .add() instead.") + if key != value.name: + raise ValueError("Cannot add elem. (key mismatch: '{}' != '{}')".format(key, value.name)) + if key in self: + raise ValueError("Split {} already present".format(key)) + super(SplitDict, self).__setitem__(key, value) def add(self, split_info: SplitInfo): """Add the split info.""" if split_info.name in self: raise ValueError("Split {} already present".format(split_info.name)) - # Forward the dataset name required to build file instructions: - # info.splits['train'].file_instructions split_info.dataset_name = self.dataset_name super(SplitDict, self).__setitem__(split_info.name, split_info) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index fdf6d41ed38..b098c0c49e2 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -7,45 +7,20 @@ import pyarrow as pa from nlp.arrow_dataset import Dataset -from nlp.arrow_reader import BaseReader from nlp.features import Features, Sequence, Value -from nlp.info import DatasetInfo -from nlp.splits import SplitDict, SplitInfo - - -class ReaderTester(BaseReader): - """ - Build a Dataset object out of Instruction instance(s). - This reader is made for testing. It mocks file reads. 
- """ - - def _get_dataset_from_filename(self, filename_skip_take): - """Returns a Dataset instance from given (filename, skip, take).""" - filename, skip, take = ( - filename_skip_take["filename"], - filename_skip_take["skip"] if "skip" in filename_skip_take else None, - filename_skip_take["take"] if "take" in filename_skip_take else None, - ) - pa_table = pa.Table.from_pydict({"filename": [filename + "_" + str(x) for x in np.arange(30).tolist()]}) - if skip is not None and take is not None: - pa_table = pa_table.slice(skip, take) - return pa_table class BaseDatasetTest(TestCase): def _create_dummy_dataset(self): - name = "my_name" - train_info = SplitInfo(name="train", num_examples=30) - test_info = SplitInfo(name="test", num_examples=30) - split_infos = [train_info, test_info] - split_dict = SplitDict() - split_dict.add(train_info) - split_dict.add(test_info) - info = DatasetInfo(splits=split_dict) - reader = ReaderTester("", info) - dset = Dataset(**reader.read(name, "train", split_infos)) + dset = Dataset( + pa.Table.from_pydict({"filename": ["my_name-train" + "_" + str(x) for x in np.arange(30).tolist()]}) + ) return dset + def test_dummy_dataset(self): + dset = self._create_dummy_dataset() + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + def test_from_pandas(self): data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]} df = pd.DataFrame.from_dict(data) @@ -53,12 +28,14 @@ def test_from_pandas(self): self.assertListEqual(dset["col_1"], data["col_1"]) self.assertListEqual(dset["col_2"], data["col_2"]) self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"]) + self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")})) features = Features({"col_1": Value("int64"), "col_2": Value("string")}) dset = Dataset.from_pandas(df, features=features) self.assertListEqual(dset["col_1"], data["col_1"]) self.assertListEqual(dset["col_2"], data["col_2"]) self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"]) + self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")})) features = Features({"col_1": Value("string"), "col_2": Value("string")}) self.assertRaises(pa.ArrowTypeError, Dataset.from_pandas, df, features=features) @@ -69,12 +46,14 @@ def test_from_dict(self): self.assertListEqual(dset["col_1"], data["col_1"]) self.assertListEqual(dset["col_2"], data["col_2"]) self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"]) + self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")})) features = Features({"col_1": Value("int64"), "col_2": Value("string")}) dset = Dataset.from_dict(data, features=features) self.assertListEqual(dset["col_1"], data["col_1"]) self.assertListEqual(dset["col_2"], data["col_2"]) self.assertListEqual(list(dset.features.keys()), ["col_1", "col_2"]) + self.assertDictEqual(dset.features, Features({"col_1": Value("int64"), "col_2": Value("string")})) features = Features({"col_1": Value("string"), "col_2": Value("string")}) self.assertRaises(pa.ArrowTypeError, Dataset.from_dict, data, features=features) @@ -87,25 +66,34 @@ def test_flatten(self): dset.flatten() self.assertListEqual(dset.column_names, ["a.b.c", "foo"]) self.assertListEqual(list(dset.features.keys()), ["a.b.c", "foo"]) + self.assertDictEqual(dset.features, Features({"a.b.c": [Value("string")], "foo": Value("int64")})) def test_map(self): dset = self._create_dummy_dataset() with tempfile.TemporaryDirectory() as tmp_dir: 
tmp_file = os.path.join(tmp_dir, "test.arrow") + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) dset_test = dset.map( lambda x: {"name": x["filename"][:-2], "id": int(x["filename"][-1])}, cache_file_name=tmp_file ) self.assertEqual(len(dset_test), 30) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual( + dset_test.features, + Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}), + ) with tempfile.TemporaryDirectory() as tmp_dir: tmp_file = os.path.join(tmp_dir, "test.arrow") - dset_test = dset.map( - lambda x: {"name": x["filename"][:-2], "id": int(x["filename"][-1])}, cache_file_name=tmp_file - ) dset_test_with_indices = dset.map( lambda x, i: {"name": x["filename"][:-2], "id": i}, with_indices=True, cache_file_name=tmp_file ) self.assertEqual(len(dset_test_with_indices), 30) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual( + dset_test_with_indices.features, + Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}), + ) def test_map_batched(self): dset = self._create_dummy_dataset() @@ -117,6 +105,10 @@ def map_batched(example): tmp_file = os.path.join(tmp_dir, "test.arrow") dset_test_batched = dset.map(map_batched, batched=True, cache_file_name=tmp_file) self.assertEqual(len(dset_test_batched), 30) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual( + dset_test_batched.features, Features({"filename": Value("string"), "filename_new": Value("string")}) + ) def map_batched_with_indices(example, idx): return {"filename_new": [x + "_extension_" + str(idx) for x in example["filename"]]} @@ -127,6 +119,11 @@ def map_batched_with_indices(example, idx): map_batched_with_indices, batched=True, with_indices=True, cache_file_name=tmp_file ) self.assertEqual(len(dset_test_with_indices_batched), 30) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual( + dset_test_with_indices_batched.features, + Features({"filename": Value("string"), "filename_new": Value("string")}), + ) def test_remove_colums(self): dset = self._create_dummy_dataset() @@ -137,11 +134,15 @@ def test_remove_colums(self): lambda x, i: {"name": x["filename"][:-2], "id": i}, with_indices=True, cache_file_name=tmp_file ) self.assertTrue("id" in dset[0]) + self.assertDictEqual( + dset.features, Features({"filename": Value("string"), "name": Value("string"), "id": Value("int64")}) + ) with tempfile.TemporaryDirectory() as tmp_dir: tmp_file = os.path.join(tmp_dir, "test.arrow") dset = dset.map(lambda x: x, remove_columns=["id"], cache_file_name=tmp_file) self.assertTrue("id" not in dset[0]) + self.assertDictEqual(dset.features, Features({"filename": Value("string"), "name": Value("string")})) def test_filter(self): dset = self._create_dummy_dataset() @@ -151,12 +152,16 @@ def test_filter(self): tmp_file = os.path.join(tmp_dir, "test.arrow") dset_filter_first_five = dset.filter(lambda x, i: i < 5, with_indices=True, cache_file_name=tmp_file) self.assertEqual(len(dset_filter_first_five), 5) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_filter_first_five.features, Features({"filename": Value("string")})) # filter filenames with even id at the end with tempfile.TemporaryDirectory() as tmp_dir: tmp_file = os.path.join(tmp_dir, "test.arrow") dset_filter_even_num = dset.filter(lambda x: 
(int(x["filename"][-1]) % 2 == 0), cache_file_name=tmp_file) self.assertEqual(len(dset_filter_even_num), 15) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_filter_even_num.features, Features({"filename": Value("string")})) def test_select(self): dset = self._create_dummy_dataset() @@ -169,6 +174,8 @@ def test_select(self): self.assertEqual(len(dset_select_even), 15) for row in dset_select_even: self.assertEqual(int(row["filename"][-1]) % 2, 0) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_select_even.features, Features({"filename": Value("string")})) def test_shuffle(self): dset = self._create_dummy_dataset() @@ -179,8 +186,10 @@ def test_shuffle(self): self.assertEqual(len(dset_shuffled), 30) self.assertEqual(dset_shuffled[0]["filename"], "my_name-train_28") self.assertEqual(dset_shuffled[2]["filename"], "my_name-train_10") + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_shuffled.features, Features({"filename": Value("string")})) - # Reproductibility + # Reproducibility tmp_file = os.path.join(tmp_dir, "test_2.arrow") dset_shuffled_2 = dset.shuffle(seed=1234, cache_file_name=tmp_file) self.assertListEqual(dset_shuffled["filename"], dset_shuffled_2["filename"]) @@ -202,11 +211,15 @@ def test_sort(self): dset_sorted = dset.sort("filename", cache_file_name=tmp_file) for i, row in enumerate(dset_sorted): self.assertEqual(int(row["filename"][-1]), i) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_sorted.features, Features({"filename": Value("string")})) # Sort reversed tmp_file = os.path.join(tmp_dir, "test_4.arrow") dset_sorted = dset.sort("filename", cache_file_name=tmp_file, reverse=True) for i, row in enumerate(dset_sorted): self.assertEqual(int(row["filename"][-1]), len(dset_sorted) - 1 - i) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_sorted.features, Features({"filename": Value("string")})) def test_train_test_split(self): dset = self._create_dummy_dataset() @@ -227,6 +240,9 @@ def test_train_test_split(self): self.assertEqual(dset_train[-1]["filename"], "my_name-train_19") self.assertEqual(dset_test[0]["filename"], "my_name-train_20") self.assertEqual(dset_test[-1]["filename"], "my_name-train_29") + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_train.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_test.features, Features({"filename": Value("string")})) tmp_file = os.path.join(tmp_dir, "test_3.arrow") tmp_file_2 = os.path.join(tmp_dir, "test_4.arrow") @@ -243,6 +259,9 @@ def test_train_test_split(self): self.assertEqual(dset_train[-1]["filename"], "my_name-train_14") self.assertEqual(dset_test[0]["filename"], "my_name-train_15") self.assertEqual(dset_test[-1]["filename"], "my_name-train_29") + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_train.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_test.features, Features({"filename": Value("string")})) tmp_file = os.path.join(tmp_dir, "test_5.arrow") tmp_file_2 = os.path.join(tmp_dir, "test_6.arrow") @@ -259,6 +278,9 @@ def test_train_test_split(self): self.assertEqual(dset_train[-1]["filename"], "my_name-train_9") self.assertEqual(dset_test[0]["filename"], 
"my_name-train_10") self.assertEqual(dset_test[-1]["filename"], "my_name-train_29") + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_train.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_test.features, Features({"filename": Value("string")})) tmp_file = os.path.join(tmp_dir, "test_7.arrow") tmp_file_2 = os.path.join(tmp_dir, "test_8.arrow") @@ -275,6 +297,9 @@ def test_train_test_split(self): self.assertNotEqual(dset_train[-1]["filename"], "my_name-train_9") self.assertNotEqual(dset_test[0]["filename"], "my_name-train_10") self.assertNotEqual(dset_test[-1]["filename"], "my_name-train_29") + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_train.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_test.features, Features({"filename": Value("string")})) def test_shard(self): dset = self._create_dummy_dataset() @@ -287,9 +312,13 @@ def test_shard(self): dset_sharded = dset.shard(num_shards=8, index=1) self.assertEqual(2, len(dset_sharded)) self.assertEqual(["my_name-train_1", "my_name-train_9"], dset_sharded["filename"]) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_sharded.features, Features({"filename": Value("string")})) # Shard contiguous dset_sharded_contiguous = dset.shard(num_shards=3, index=0, contiguous=True) self.assertEqual([f"my_name-train_{i}" for i in (0, 1, 2, 3)], dset_sharded_contiguous["filename"]) + self.assertDictEqual(dset.features, Features({"filename": Value("string")})) + self.assertDictEqual(dset_sharded.features, Features({"filename": Value("string")})) # Test lengths of sharded contiguous self.assertEqual([4, 3, 3], [len(dset.shard(3, index=i, contiguous=True)) for i in range(3)]) @@ -309,6 +338,7 @@ def test_format_vectors(self): for col in columns: self.assertIsInstance(dset[0][col], (str, list)) self.assertIsInstance(dset[:2][col], list) + self.assertDictEqual(dset.features, Features({"filename": Value("string"), "vec": [Value("float64")]})) # don't test if torch and tensorflow are stacked accross examples # we need to use the features definition to know at what depth we have to to the conversion @@ -348,6 +378,9 @@ def test_format_nested(self): cache_file_name=tmp_file, batched=True, ) + self.assertDictEqual( + dset.features, Features({"filename": Value("string"), "nested": {"foo": [Value("float64")]}}) + ) dset.set_format("tensorflow") self.assertIsNotNone(dset[0]) diff --git a/tests/test_features.py b/tests/test_features.py index 628d5121145..4e08e407a55 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -9,8 +9,9 @@ def test_from_arrow_schema_simple(self): data = {"a": [{"b": {"c": "text"}}] * 10, "foo": [1] * 10} original_features = Features({"a": {"b": {"c": Value("string")}}, "foo": Value("int64")}) dset = Dataset.from_dict(data, features=original_features) - new_features = Features.from_arrow_schema(dset.schema) + new_features = dset.features new_dset = Dataset.from_dict(data, features=new_features) + self.assertEqual(original_features.type, new_features.type) self.assertDictEqual(dset[0], new_dset[0]) self.assertDictEqual(dset[:], new_dset[:]) @@ -18,7 +19,8 @@ def test_from_arrow_schema_with_sequence(self): data = {"a": [{"b": {"c": ["text"]}}] * 10, "foo": [1] * 10} original_features = Features({"a": {"b": Sequence({"c": Value("string")})}, "foo": Value("int64")}) dset = Dataset.from_dict(data, 
features=original_features) - new_features = Features.from_arrow_schema(dset.schema) + new_features = dset.features new_dset = Dataset.from_dict(data, features=new_features) + self.assertEqual(original_features.type, new_features.type) self.assertDictEqual(dset[0], new_dset[0]) self.assertDictEqual(dset[:], new_dset[:])
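And on the writer side, a minimal sketch (not taken from this diff; column names and values are illustrative) of ArrowWriter being driven by Features instead of a pa.Schema, together with the new "double"/"float" dtype aliases that keep schema-inferred features comparable to user-declared ones:

    import pyarrow as pa

    from nlp.arrow_writer import ArrowWriter
    from nlp.features import Features, Value

    # Arrow reports 64-bit floats as "double"; Value.__post_init__ now aliases that
    # to "float64", so a round trip through Features.from_arrow_schema compares equal.
    inferred = Features.from_arrow_schema(pa.schema([pa.field("score", pa.float64())]))
    assert inferred == Features({"score": Value("float64")})

    # Passing Features builds the schema (and the RecordBatchStreamWriter) up front,
    # so no example rows have to be inspected before writing starts.
    features = Features({"filename": Value("string"), "score": Value("float64")})
    buf_writer = pa.BufferOutputStream()
    writer = ArrowWriter(features=features, stream=buf_writer, writer_batch_size=2)
    writer.write({"filename": "my_name-train_0", "score": 0.5})
    writer.write({"filename": "my_name-train_1", "score": 1.5})
    writer.finalize()  # flushes any pending rows and closes the stream

Deriving the schema from Features up front is what lets Dataset.map() and filter() drop the up-front dry runs that existed only to infer an Arrow schema for the cache writer.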