diff --git a/benchmarks/utils.py b/benchmarks/utils.py
index 6e304f6a9e7..feb13d9c8fa 100644
--- a/benchmarks/utils.py
+++ b/benchmarks/utils.py
@@ -3,7 +3,8 @@
 import numpy as np
 
 import datasets
-from datasets.features import _ArrayXD
+from datasets.arrow_writer import ArrowWriter
+from datasets.features.features import _ArrayXD
 
 
 def get_duration(func):
@@ -46,7 +47,7 @@ def generate_examples(features: dict, num_examples=100, seq_shapes=None):
 def generate_example_dataset(dataset_path, features, num_examples=100, seq_shapes=None):
     dummy_data = generate_examples(features, num_examples=num_examples, seq_shapes=seq_shapes)
 
-    with datasets.ArrowWriter(features=features, path=dataset_path) as writer:
+    with ArrowWriter(features=features, path=dataset_path) as writer:
         for key, record in dummy_data:
             example = features.encode_example(record)
             writer.write(example)
diff --git a/metrics/perplexity/perplexity.py b/metrics/perplexity/perplexity.py
index 427a8896ef5..0c6b9ead819 100644
--- a/metrics/perplexity/perplexity.py
+++ b/metrics/perplexity/perplexity.py
@@ -17,7 +17,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import datasets
-from datasets.utils import tqdm
+from datasets import utils
 
 
 _CITATION = """\
@@ -113,7 +113,7 @@ def _compute(self, input_texts, model_id, stride=512, device=None):
 
         ppls = []
 
-        for text_index in tqdm(range(0, len(encoded_texts))):
+        for text_index in utils.tqdm_utils.tqdm(range(0, len(encoded_texts))):
            encoded_text = encoded_texts[text_index]
            special_tokens_mask = special_tokens_masks[text_index]
diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py
index fbfee81e2a1..85825769559 100644
--- a/src/datasets/__init__.py
+++ b/src/datasets/__init__.py
@@ -20,36 +20,26 @@
 __version__ = "1.18.5.dev0"
 
 import pyarrow
-from packaging import version as _version
-from pyarrow import total_allocated_bytes
+from packaging import version
 
 
-if _version.parse(pyarrow.__version__).major < 5:
+if version.parse(pyarrow.__version__).major < 5:
     raise ImportWarning(
         "To use `datasets`, the module `pyarrow>=5.0.0` is required, and the current version of `pyarrow` doesn't match this condition.\n"
         "If you are running this in a Google Colab, you should probably just restart the runtime to use the right version of `pyarrow`."
     )
 
+SCRIPTS_VERSION = "master" if version.parse(__version__).is_devrelease else __version__
+
+del pyarrow
+del version
+
 from .arrow_dataset import Dataset, concatenate_datasets
-from .arrow_reader import ArrowReader, ReadInstruction
-from .arrow_writer import ArrowWriter
+from .arrow_reader import ReadInstruction
 from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
 from .combine import interleave_datasets
 from .dataset_dict import DatasetDict, IterableDatasetDict
-from .features import (
-    Array2D,
-    Array3D,
-    Array4D,
-    Array5D,
-    Audio,
-    ClassLabel,
-    Features,
-    Image,
-    Sequence,
-    Translation,
-    TranslationVariableLanguages,
-    Value,
-)
+from .features import *
 from .fingerprint import is_caching_enabled, set_caching_enabled
 from .info import DatasetInfo, MetricInfo
 from .inspect import (
@@ -63,8 +53,7 @@
     list_metrics,
 )
 from .iterable_dataset import IterableDataset
-from .keyhash import KeyHasher
-from .load import import_main_class, load_dataset, load_dataset_builder, load_from_disk, load_metric
+from .load import load_dataset, load_dataset_builder, load_from_disk, load_metric
 from .metric import Metric
 from .splits import (
     NamedSplit,
@@ -79,6 +68,4 @@
 )
 from .tasks import *
 from .utils import *
-
-
-SCRIPTS_VERSION = "master" if _version.parse(__version__).is_devrelease else __version__
+from .utils import logging
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 1d190f2daae..91dd074f051 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -58,19 +58,8 @@
 from . import config, utils
 from .arrow_reader import ArrowReader
 from .arrow_writer import ArrowWriter, OptimizedTypedSequence
-from .features import (
-    Audio,
-    ClassLabel,
-    Features,
-    FeatureType,
-    Image,
-    Sequence,
-    Value,
-    _ArrayXD,
-    decode_nested_example,
-    pandas_types_mapper,
-    require_decoding,
-)
+from .features import Audio, ClassLabel, Features, Image, Sequence, Value
+from .features.features import FeatureType, _ArrayXD, decode_nested_example, pandas_types_mapper, require_decoding
 from .filesystems import extract_path_from_uri, is_remote_filesystem
 from .fingerprint import (
     fingerprint_transform,
@@ -2316,7 +2305,7 @@ def init_buffer_and_writer():
            pbar_total = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
            pbar_unit = "ex" if not batched else "ba"
            pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
-            pbar = utils.tqdm(
+            pbar = utils.tqdm_utils.tqdm(
                pbar_iterable,
                total=pbar_total,
                disable=disable_tqdm,
@@ -3465,7 +3454,7 @@ def delete_file(file):
                api.delete_file(file, repo_id=repo_id, token=token, repo_type="dataset", revision=branch)
 
            if len(file_shards_to_delete):
-                for file in utils.tqdm(
+                for file in utils.tqdm_utils.tqdm(
                    file_shards_to_delete,
                    desc="Deleting unused files from dataset repository",
                    total=len(file_shards_to_delete),
@@ -3474,7 +3463,7 @@ def delete_file(file):
                    delete_file(file)
 
        uploaded_size = 0
-        for index, shard in utils.tqdm(
+        for index, shard in utils.tqdm_utils.tqdm(
            enumerate(shards),
            desc="Pushing dataset shards to the dataset hub",
            total=num_shards,
diff --git a/src/datasets/arrow_reader.py b/src/datasets/arrow_reader.py
index 0ce6eea122b..9da2873ca10 100644
--- a/src/datasets/arrow_reader.py
+++ b/src/datasets/arrow_reader.py
@@ -26,11 +26,10 @@
 import pyarrow as pa
 import pyarrow.parquet as pq
 
-from datasets.utils.file_utils import DownloadConfig
-
 from .naming import _split_re, filename_for_dataset_split
 from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables
-from .utils import cached_path, logging
+from .utils import logging
+from .utils.file_utils import DownloadConfig, cached_path
 
 
 if TYPE_CHECKING:
diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index e7ac88d9253..6cdc1c4a364 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -22,12 +22,10 @@
 import numpy as np
 import pyarrow as pa
 
-from datasets.features.features import FeatureType, Value
-
 from . import config, utils
-from .features import (
-    Features,
-    Image,
+from .features import Features, Image, Value
+from .features.features import (
+    FeatureType,
     _ArrayXDExtensionType,
     cast_to_python_objects,
     generate_from_arrow_type,
@@ -641,9 +639,9 @@ def parquet_to_arrow(sources, destination):
     stream = None if isinstance(destination, str) else destination
     disable = not utils.is_progress_bar_enabled()
     with ArrowWriter(path=destination, stream=stream) as writer:
-        for source in utils.tqdm(sources, unit="sources", disable=disable):
+        for source in utils.tqdm_utils.tqdm(sources, unit="sources", disable=disable):
             pf = pa.parquet.ParquetFile(source)
-            for i in utils.tqdm(range(pf.num_row_groups), unit="row_groups", leave=False, disable=disable):
+            for i in utils.tqdm_utils.tqdm(range(pf.num_row_groups), unit="row_groups", leave=False, disable=disable):
                 df = pf.read_row_group(i).to_pandas()
                 for col in df.columns:
                     df[col] = df[col].apply(json.loads)
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 7b05ca21b57..a5b098d9d51 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -27,9 +27,6 @@
 from functools import partial
 from typing import Dict, Mapping, Optional, Tuple, Union
 
-from datasets.features import Features
-from datasets.utils.mock_download_manager import MockDownloadManager
-
 from . import config, utils
 from .arrow_dataset import Dataset
 from .arrow_reader import (
@@ -42,6 +39,7 @@
 from .arrow_writer import ArrowWriter, BeamWriter
 from .data_files import DataFilesDict, sanitize_patterns
 from .dataset_dict import DatasetDict, IterableDatasetDict
+from .features import Features
 from .fingerprint import Hasher
 from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
 from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper
@@ -49,9 +47,18 @@
 from .splits import Split, SplitDict, SplitGenerator
 from .utils import logging
 from .utils.download_manager import DownloadManager, DownloadMode
-from .utils.file_utils import DownloadConfig, is_remote_url
+from .utils.file_utils import DownloadConfig, cached_path, is_remote_url
 from .utils.filelock import FileLock
 from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits
+from .utils.mock_download_manager import MockDownloadManager
+from .utils.py_utils import (
+    classproperty,
+    has_sufficient_disk_space,
+    map_nested,
+    memoize,
+    size_str,
+    temporary_assignment,
+)
 from .utils.streaming_download_manager import StreamingDownloadManager
@@ -389,9 +396,9 @@ def _create_builder_config(self, name=None, custom_features=None, **config_kwarg
 
         return builder_config, config_id
 
-    @utils.classproperty
+    @classproperty
     @classmethod
-    @utils.memoize()
+    @memoize()
     def builder_configs(cls):
         """Pre-defined list of configurations for this builder class."""
         configs = {config.name: config for config in cls.BUILDER_CONFIGS}
@@ -537,9 +544,9 @@ def download_and_prepare(
                 return
         logger.info(f"Generating dataset {self.name} ({self._cache_dir})")
         if not is_remote_url(self._cache_dir_root):  # if cache dir is local, check for available space
-            if not utils.has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root):
+            if not has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root):
                 raise OSError(
-                    f"Not enough disk space. Needed: {utils.size_str(self.info.size_in_bytes or 0)} (download: {utils.size_str(self.info.download_size or 0)}, generated: {utils.size_str(self.info.dataset_size or 0)}, post-processed: {utils.size_str(self.info.post_processing_size or 0)})"
+                    f"Not enough disk space. Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})"
                 )
 
         @contextlib.contextmanager
@@ -565,9 +572,9 @@ def incomplete_dir(dirname):
         if self.info.size_in_bytes:
             print(
                 f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} "
-                f"(download: {utils.size_str(self.info.download_size)}, generated: {utils.size_str(self.info.dataset_size)}, "
-                f"post-processed: {utils.size_str(self.info.post_processing_size)}, "
-                f"total: {utils.size_str(self.info.size_in_bytes)}) to {self._cache_dir}..."
+                f"(download: {size_str(self.info.download_size)}, generated: {size_str(self.info.dataset_size)}, "
+                f"post-processed: {size_str(self.info.post_processing_size)}, "
+                f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..."
             )
         else:
             print(
@@ -580,7 +587,7 @@ def incomplete_dir(dirname):
         with incomplete_dir(self._cache_dir) as tmp_data_dir:
             # Temporarily assign _cache_dir to tmp_data_dir to avoid having to forward
             # it to every sub function.
-            with utils.temporary_assignment(self, "_cache_dir", tmp_data_dir):
+            with temporary_assignment(self, "_cache_dir", tmp_data_dir):
                 # Try to download the already prepared dataset files
                 downloaded_from_gcs = False
                 if try_from_hf_gcs:
@@ -637,7 +644,7 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig):
            if os.sep in resource_file_name:
                raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
            try:
-                resource_path = utils.cached_path(remote_cache_dir + "/" + resource_file_name)
+                resource_path = cached_path(remote_cache_dir + "/" + resource_file_name)
                shutil.move(resource_path, os.path.join(self._cache_dir, resource_file_name))
            except ConnectionError:
                logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.")
@@ -761,7 +768,7 @@ def as_dataset(
            split = {s: s for s in self.info.splits}
 
        # Create a dataset for each of the given splits
-        datasets = utils.map_nested(
+        datasets = map_nested(
            partial(
                self._build_single_dataset,
                run_post_process=run_post_process,
@@ -903,7 +910,7 @@ def as_streaming_dataset(
            raise ValueError(f"Bad split: {split}. Available splits: {list(splits_generators)}")
 
        # Create a dataset for each of the given splits
-        datasets = utils.map_nested(
+        datasets = map_nested(
            self._as_streaming_dataset_single,
            splits_generator,
            map_tuple=True,
@@ -1074,7 +1081,7 @@ def _prepare_split(self, split_generator):
            check_duplicates=True,
        ) as writer:
            try:
-                for key, record in utils.tqdm(
+                for key, record in utils.tqdm_utils.tqdm(
                    generator,
                    unit=" examples",
                    total=split_info.num_examples,
@@ -1135,7 +1142,7 @@ def _prepare_split(self, split_generator):
        generator = self._generate_tables(**split_generator.gen_kwargs)
        with ArrowWriter(features=self.info.features, path=fpath) as writer:
-            for key, table in utils.tqdm(
+            for key, table in utils.tqdm_utils.tqdm(
                generator, unit=" tables", leave=False, disable=True  # not utils.is_progress_bar_enabled()
            ):
                writer.write_table(table)
diff --git a/src/datasets/commands/dummy_data.py b/src/datasets/commands/dummy_data.py
index 7030f182568..c29f515d8f0 100644
--- a/src/datasets/commands/dummy_data.py
+++ b/src/datasets/commands/dummy_data.py
@@ -11,10 +11,10 @@
 from datasets import config
 from datasets.commands import BaseDatasetsCLICommand
 from datasets.load import dataset_module_factory, import_main_class
-from datasets.utils import MockDownloadManager
 from datasets.utils.download_manager import DownloadManager
 from datasets.utils.file_utils import DownloadConfig
 from datasets.utils.logging import get_logger, set_verbosity_warning
+from datasets.utils.mock_download_manager import MockDownloadManager
 from datasets.utils.py_utils import map_nested
diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py
index 0b3ba154201..38d002dfa7f 100644
--- a/src/datasets/data_files.py
+++ b/src/datasets/data_files.py
@@ -7,8 +7,7 @@
 from fsspec.implementations.local import LocalFileSystem
 from tqdm.contrib.concurrent import thread_map
 
-from datasets.filesystems.hffilesystem import HfFileSystem
-
+from .filesystems.hffilesystem import HfFileSystem
 from .splits import Split
 from .utils import logging
 from .utils.file_utils import hf_hub_url, is_remote_url, request_etag
diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py
index 23dfee60170..d8f72f31d6f 100644
--- a/src/datasets/features/__init__.py
+++ b/src/datasets/features/__init__.py
@@ -1,12 +1,20 @@
 # flake8: noqa
+
+__all__ = [
+    "Audio",
+    "Array2D",
+    "Array3D",
+    "Array4D",
+    "Array5D",
+    "ClassLabel",
+    "Features",
+    "Sequence",
+    "Value",
+    "Image",
+    "Translation",
+    "TranslationVariableLanguages",
+]
 from .audio import Audio
-from .features import *
-from .features import (
-    _ArrayXD,
-    _ArrayXDExtensionType,
-    _arrow_to_datasets_dtype,
-    _cast_to_python_objects,
-    _is_zero_copy_only,
-)
-from .image import Image, objects_to_list_of_image_dicts
+from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, Sequence, Value
+from .image import Image
 from .translation import Translation, TranslationVariableLanguages
diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py
index d2f475b1728..88aa2f52d05 100644
--- a/src/datasets/features/features.py
+++ b/src/datasets/features/features.py
@@ -33,14 +33,15 @@
 from pandas.api.extensions import ExtensionArray as PandasExtensionArray
 from pandas.api.extensions import ExtensionDtype as PandasExtensionDtype
 
-from datasets import config, utils
-from datasets.features.audio import Audio
-from datasets.features.image import Image, encode_pil_image
-from datasets.features.translation import Translation, TranslationVariableLanguages
-from datasets.utils.logging import get_logger
+from .. import config
+from ..utils import logging
+from ..utils.py_utils import zip_dict
+from .audio import Audio
+from .image import Image, encode_pil_image
+from .translation import Translation, TranslationVariableLanguages
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str:
@@ -969,9 +970,7 @@ def encode_nested_example(schema, obj):
     """
     # Nested structures: we allow dict, list/tuples, sequences
     if isinstance(schema, dict):
-        return {
-            k: encode_nested_example(sub_schema, sub_obj) for k, (sub_schema, sub_obj) in utils.zip_dict(schema, obj)
-        }
+        return {k: encode_nested_example(sub_schema, sub_obj) for k, (sub_schema, sub_obj) in zip_dict(schema, obj)}
     elif isinstance(schema, (list, tuple)):
         sub_schema = schema[0]
         if obj is None:
@@ -991,12 +990,12 @@ def encode_nested_example(schema, obj):
             list_dict = {}
             if isinstance(obj, (list, tuple)):
                 # obj is a list of dict
-                for k, dict_tuples in utils.zip_dict(schema.feature, *obj):
+                for k, dict_tuples in zip_dict(schema.feature, *obj):
                     list_dict[k] = [encode_nested_example(dict_tuples[0], o) for o in dict_tuples[1:]]
                 return list_dict
             else:
                 # obj is a single dict
-                for k, (sub_schema, sub_objs) in utils.zip_dict(schema.feature, obj):
+                for k, (sub_schema, sub_objs) in zip_dict(schema.feature, obj):
                     list_dict[k] = [encode_nested_example(sub_schema, o) for o in sub_objs]
                 return list_dict
         # schema.feature is not a dict
@@ -1030,9 +1029,7 @@ def decode_nested_example(schema, obj):
     """
     # Nested structures: we allow dict, list/tuples, sequences
     if isinstance(schema, dict):
-        return {
-            k: decode_nested_example(sub_schema, sub_obj) for k, (sub_schema, sub_obj) in utils.zip_dict(schema, obj)
-        }
+        return {k: decode_nested_example(sub_schema, sub_obj) for k, (sub_schema, sub_obj) in zip_dict(schema, obj)}
     elif isinstance(schema, (list, tuple)):
         sub_schema = schema[0]
         if obj is None:
@@ -1328,7 +1325,7 @@ def decode_example(self, example: dict):
             column_name: decode_nested_example(feature, value)
             if self._column_requires_decoding[column_name]
             else value
-            for column_name, (feature, value) in utils.zip_dict(
+            for column_name, (feature, value) in zip_dict(
                 {key: value for key, value in self.items() if key in example}, example
             )
         }
diff --git a/src/datasets/fingerprint.py b/src/datasets/fingerprint.py
index 312d3526191..b02d703e215 100644
--- a/src/datasets/fingerprint.py
+++ b/src/datasets/fingerprint.py
@@ -14,9 +14,8 @@
 import pyarrow as pa
 import xxhash
 
-from datasets.table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table
-
 from .info import DatasetInfo
+from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table
 from .utils.logging import get_logger
 from .utils.py_utils import dumps
diff --git a/src/datasets/formatting/__init__.py b/src/datasets/formatting/__init__.py
index 1a737919025..d9b0f44f950 100644
--- a/src/datasets/formatting/__init__.py
+++ b/src/datasets/formatting/__init__.py
@@ -18,7 +18,7 @@
 from typing import Dict, List, Optional
 
 from .. import config
-from ..utils.logging import get_logger
+from ..utils import logging
 from .formatting import (
     ArrowFormatter,
     CustomFormatter,
@@ -31,7 +31,7 @@
 )
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 _FORMAT_TYPES: Dict[Optional[str], type] = {}
 _FORMAT_TYPES_ALIASES: Dict[Optional[str], str] = {}
diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py
index c0899f335e7..f857a53a992 100644
--- a/src/datasets/formatting/formatting.py
+++ b/src/datasets/formatting/formatting.py
@@ -21,9 +21,9 @@
 import pandas as pd
 import pyarrow as pa
 
-from ..features import _ArrayXDExtensionType, _is_zero_copy_only, decode_nested_example, pandas_types_mapper
+from ..features.features import _ArrayXDExtensionType, _is_zero_copy_only, decode_nested_example, pandas_types_mapper
 from ..table import Table
-from ..utils import no_op_if_value_is_null
+from ..utils.py_utils import no_op_if_value_is_null
 
 
 T = TypeVar("T")
diff --git a/src/datasets/io/csv.py b/src/datasets/io/csv.py
index 43f3fc67c12..a2a33a9db92 100644
--- a/src/datasets/io/csv.py
+++ b/src/datasets/io/csv.py
@@ -104,7 +104,7 @@ def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> in
         written = 0
 
         if self.num_proc is None or self.num_proc == 1:
-            for offset in utils.tqdm(
+            for offset in utils.tqdm_utils.tqdm(
                 range(0, len(self.dataset), self.batch_size),
                 unit="ba",
                 disable=not utils.is_progress_bar_enabled(),
@@ -116,7 +116,7 @@ def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> in
         else:
             num_rows, batch_size = len(self.dataset), self.batch_size
             with multiprocessing.Pool(self.num_proc) as pool:
-                for csv_str in utils.tqdm(
+                for csv_str in utils.tqdm_utils.tqdm(
                     pool.imap(
                         self._batch_csv,
                         [(offset, header, to_csv_kwargs) for offset in range(0, num_rows, batch_size)],
diff --git a/src/datasets/io/json.py b/src/datasets/io/json.py
index 6b31d45efc3..83283d01a2f 100644
--- a/src/datasets/io/json.py
+++ b/src/datasets/io/json.py
@@ -126,7 +126,7 @@ def _write(
         written = 0
 
         if self.num_proc is None or self.num_proc == 1:
-            for offset in utils.tqdm(
+            for offset in utils.tqdm_utils.tqdm(
                 range(0, len(self.dataset), self.batch_size),
                 unit="ba",
                 disable=not utils.is_progress_bar_enabled(),
@@ -137,7 +137,7 @@ def _write(
         else:
             num_rows, batch_size = len(self.dataset), self.batch_size
             with multiprocessing.Pool(self.num_proc) as pool:
-                for json_str in utils.tqdm(
+                for json_str in utils.tqdm_utils.tqdm(
                     pool.imap(
                         self._batch_json,
                         [(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index b255f9d7c79..c382b193b10 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -8,7 +8,8 @@
 import pyarrow as pa
 
 from .arrow_dataset import DatasetInfoMixin
-from .features import Features, FeatureType
+from .features import Features
+from .features.features import FeatureType
 from .formatting import PythonFormatter
 from .info import DatasetInfo
 from .splits import NamedSplit
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 97bed34a699..41edc28371b 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -48,7 +48,7 @@
 from .info import DatasetInfo, DatasetInfosDict
 from .iterable_dataset import IterableDataset
 from .metric import Metric
-from .packaged_modules import _EXTENSION_TO_MODULE, _PACKAGED_DATASETS_MODULES, hash_python_lines
+from .packaged_modules import _EXTENSION_TO_MODULE, _PACKAGED_DATASETS_MODULES, _hash_python_lines
 from .splits import Split
 from .streaming import extend_module_for_streaming
 from .tasks import TaskTemplate
@@ -139,7 +139,7 @@ def files_to_hash(file_paths: List[str]) -> str:
     for file_path in to_use_files:
         with open(file_path, encoding="utf-8") as f:
             lines.extend(f.readlines())
-    return hash_python_lines(lines)
+    return _hash_python_lines(lines)
 
 
 def convert_github_url(url_path: str) -> Tuple[str, Optional[str]]:
diff --git a/src/datasets/metric.py b/src/datasets/metric.py
index fba77774881..60f139d0824 100644
--- a/src/datasets/metric.py
+++ b/src/datasets/metric.py
@@ -29,11 +29,11 @@
 from .features import Features
 from .info import DatasetInfo, MetricInfo
 from .naming import camelcase_to_snakecase
-from .utils import copyfunc, temp_seed
 from .utils.download_manager import DownloadManager
 from .utils.file_utils import DownloadConfig
 from .utils.filelock import BaseFileLock, FileLock, Timeout
 from .utils.logging import get_logger
+from .utils.py_utils import copyfunc, temp_seed
 
 
 logger = get_logger(__name__)
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index 4ec80488564..77fe7f2dcaa 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -11,7 +11,7 @@
 from .text import text
 
 
-def hash_python_lines(lines: List[str]) -> str:
+def _hash_python_lines(lines: List[str]) -> str:
     filtered_lines = []
     for line in lines:
         line = re.sub(r"#.*", "", line)  # remove comments
@@ -26,12 +26,12 @@ def hash_python_lines(lines: List[str]) -> str:
 
 # get importable module names and hash for caching
 _PACKAGED_DATASETS_MODULES = {
-    "csv": (csv.__name__, hash_python_lines(inspect.getsource(csv).splitlines())),
-    "json": (json.__name__, hash_python_lines(inspect.getsource(json).splitlines())),
-    "pandas": (pandas.__name__, hash_python_lines(inspect.getsource(pandas).splitlines())),
-    "parquet": (parquet.__name__, hash_python_lines(inspect.getsource(parquet).splitlines())),
-    "text": (text.__name__, hash_python_lines(inspect.getsource(text).splitlines())),
-    "imagefolder": (imagefolder.__name__, hash_python_lines(inspect.getsource(imagefolder).splitlines())),
+    "csv": (csv.__name__, _hash_python_lines(inspect.getsource(csv).splitlines())),
+    "json": (json.__name__, _hash_python_lines(inspect.getsource(json).splitlines())),
+    "pandas": (pandas.__name__, _hash_python_lines(inspect.getsource(pandas).splitlines())),
+    "parquet": (parquet.__name__, _hash_python_lines(inspect.getsource(parquet).splitlines())),
+    "text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())),
+    "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
 }
 
 _EXTENSION_TO_MODULE = {
diff --git a/src/datasets/search.py b/src/datasets/search.py
index c970fb748f3..f9c3e442f05 100644
--- a/src/datasets/search.py
+++ b/src/datasets/search.py
@@ -150,7 +150,9 @@ def add_documents(self, documents: Union[List[str], "Dataset"], column: Optional
            index_config = self.es_index_config
            self.es_client.indices.create(index=index_name, body=index_config)
        number_of_docs = len(documents)
-        progress = utils.tqdm(unit="docs", total=number_of_docs, disable=not utils.is_progress_bar_enabled())
+        progress = utils.tqdm_utils.tqdm(
+            unit="docs", total=number_of_docs, disable=not utils.is_progress_bar_enabled()
+        )
        successes = 0
 
        def passage_generator():
@@ -293,7 +295,9 @@ def add_vectors(
 
        # Add vectors
        logger.info(f"Adding {len(vectors)} vectors to the faiss index")
-        for i in utils.tqdm(range(0, len(vectors), batch_size), disable=not utils.is_progress_bar_enabled()):
+        for i in utils.tqdm_utils.tqdm(
+            range(0, len(vectors), batch_size), disable=not utils.is_progress_bar_enabled()
+        ):
            vecs = vectors[i : i + batch_size] if column is None else vectors[i : i + batch_size][column]
            self.faiss_index.add(vecs)
diff --git a/src/datasets/table.py b/src/datasets/table.py
index fc16abef42e..18dcf7ee8d3 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -13,7 +13,7 @@
 
 
 if TYPE_CHECKING:
-    from .features import Features, FeatureType
+    from .features.features import Features, FeatureType
 
 
 logger = get_logger(__name__)
@@ -1762,7 +1762,7 @@ def cast_array_to_feature(array: pa.Array, feature: "FeatureType", allow_number_
     Returns:
         array (:obj:`pyarrow.Array`): the casted array
     """
-    from .features import Sequence, get_nested_type
+    from .features.features import Sequence, get_nested_type
 
     _c = partial(cast_array_to_feature, allow_number_to_str=allow_number_to_str)
diff --git a/src/datasets/utils/__init__.py b/src/datasets/utils/__init__.py
index 195620b7c46..b24a68a91e4 100644
--- a/src/datasets/utils/__init__.py
+++ b/src/datasets/utils/__init__.py
@@ -16,26 +16,16 @@
 # Lint as: python3
 """Util import."""
 
-from . import logging
-from .download_manager import DownloadManager, DownloadMode, GenerateMode
-from .file_utils import DownloadConfig, cached_path, hf_bucket_url, is_remote_url, relative_to_absolute_path, temp_seed
-from .mock_download_manager import MockDownloadManager
-from .py_utils import (
-    NonMutableDict,
-    classproperty,
-    copyfunc,
-    dumps,
-    first_non_null_value,
-    flatten_nest_dict,
-    has_sufficient_disk_space,
-    map_nested,
-    memoize,
-    no_op_if_value_is_null,
-    size_str,
-    temporary_assignment,
-    unique_values,
-    zip_dict,
-    zip_nested,
-)
-from .tqdm_utils import disable_progress_bar, is_progress_bar_enabled, set_progress_bar_enabled, tqdm
+__all__ = [
+    "DownloadConfig",
+    "DownloadManager",
+    "DownloadMode",
+    "disable_progress_bar",
+    "is_progress_bar_enabled",
+    "set_progress_bar_enabled",
+    "Version",
+]
+
+from .download_manager import DownloadConfig, DownloadManager, DownloadMode
+from .tqdm_utils import disable_progress_bar, is_progress_bar_enabled, set_progress_bar_enabled
 from .version import Version
diff --git a/src/datasets/utils/download_manager.py b/src/datasets/utils/download_manager.py
index 00137488163..66a2797fb60 100644
--- a/src/datasets/utils/download_manager.py
+++ b/src/datasets/utils/download_manager.py
@@ -184,7 +184,7 @@ def ship_files_with_pipeline(self, downloaded_path_or_paths, pipeline):
         """
         Ship the files using Beam FileSystems to the pipeline temp dir.
         """
-        from datasets.utils.beam_utils import upload_local_to_remote
+        from .beam_utils import upload_local_to_remote
 
         remote_dir = pipeline._options.get_all_options().get("temp_location")
         if remote_dir is None:
diff --git a/src/datasets/utils/extract.py b/src/datasets/utils/extract.py
index fdee06f8798..bf6f5963033 100644
--- a/src/datasets/utils/extract.py
+++ b/src/datasets/utils/extract.py
@@ -7,8 +7,8 @@
 from zipfile import ZipFile
 from zipfile import is_zipfile as _is_zipfile
 
-from datasets import config
-from datasets.utils.filelock import FileLock
+from .. import config
+from .filelock import FileLock
 
 
 class ExtractManager:
@@ -19,7 +19,7 @@ def __init__(self, cache_dir=None):
         self.extractor = Extractor
 
     def _get_output_path(self, path):
-        from datasets.utils.file_utils import hash_url_to_filename
+        from .file_utils import hash_url_to_filename
 
         # Path where we extract compressed archives
         # We extract in the cache dir, and get the extracted path name by hashing the original path"
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index 0750bb4e967..ed3c3615895 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -23,7 +23,6 @@
 from typing import Dict, Optional, TypeVar, Union
 from urllib.parse import urljoin, urlparse
 
-import numpy as np
 import requests
 
 from .. import __version__, config, utils
@@ -58,60 +57,6 @@ def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None) -> str:
     return hf_modules_cache
 
 
-@contextmanager
-def temp_seed(seed: int, set_pytorch=False, set_tensorflow=False):
-    """Temporarily set the random seed. This works for python numpy, pytorch and tensorflow."""
-    np_state = np.random.get_state()
-    np.random.seed(seed)
-
-    if set_pytorch and config.TORCH_AVAILABLE:
-        import torch
-
-        torch_state = torch.random.get_rng_state()
-        torch.random.manual_seed(seed)
-
-        if torch.cuda.is_available():
-            torch_cuda_states = torch.cuda.get_rng_state_all()
-            torch.cuda.manual_seed_all(seed)
-
-    if set_tensorflow and config.TF_AVAILABLE:
-        import tensorflow as tf
-        from tensorflow.python import context as tfpycontext
-
-        tf_state = tf.random.get_global_generator()
-        temp_gen = tf.random.Generator.from_seed(seed)
-        tf.random.set_global_generator(temp_gen)
-
-        if not tf.executing_eagerly():
-            raise ValueError("Setting random seed for TensorFlow is only available in eager mode")
-
-        tf_context = tfpycontext.context()  # eager mode context
-        tf_seed = tf_context._seed
-        tf_rng_initialized = hasattr(tf_context, "_rng")
-        if tf_rng_initialized:
-            tf_rng = tf_context._rng
-        tf_context._set_global_seed(seed)
-
-    try:
-        yield
-    finally:
-        np.random.set_state(np_state)
-
-        if set_pytorch and config.TORCH_AVAILABLE:
-            torch.random.set_rng_state(torch_state)
-            if torch.cuda.is_available():
-                torch.cuda.set_rng_state_all(torch_cuda_states)
-
-        if set_tensorflow and config.TF_AVAILABLE:
-            tf.random.set_global_generator(tf_state)
-
-            tf_context._seed = tf_seed
-            if tf_rng_initialized:
-                tf_context._rng = tf_rng
-            else:
-                delattr(tf_context, "_rng")
-
-
 def is_remote_url(url_or_filename: str) -> bool:
     parsed = urlparse(url_or_filename)
     return parsed.scheme in ("http", "https", "s3", "gs", "hdfs", "ftp")
@@ -447,7 +392,7 @@ def http_get(
             return
     content_length = response.headers.get("Content-Length")
     total = resume_size + int(content_length) if content_length is not None else None
-    with utils.tqdm(
+    with utils.tqdm_utils.tqdm(
         unit="B",
         unit_scale=True,
         total=total,
diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index 54880e44a94..998dee59cac 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -25,6 +25,7 @@
 import re
 import sys
 import types
+from contextlib import contextmanager
 from io import BytesIO as StringIO
 from multiprocessing import Pool, RLock
 from shutil import disk_usage
@@ -35,7 +36,7 @@
 import numpy as np
 from tqdm.auto import tqdm
 
-from .. import utils
+from .. import config, utils
 from . import logging
@@ -119,6 +120,60 @@ def temporary_assignment(obj, attr, value):
         setattr(obj, attr, original)
 
 
+@contextmanager
+def temp_seed(seed: int, set_pytorch=False, set_tensorflow=False):
+    """Temporarily set the random seed. This works for python numpy, pytorch and tensorflow."""
+    np_state = np.random.get_state()
+    np.random.seed(seed)
+
+    if set_pytorch and config.TORCH_AVAILABLE:
+        import torch
+
+        torch_state = torch.random.get_rng_state()
+        torch.random.manual_seed(seed)
+
+        if torch.cuda.is_available():
+            torch_cuda_states = torch.cuda.get_rng_state_all()
+            torch.cuda.manual_seed_all(seed)
+
+    if set_tensorflow and config.TF_AVAILABLE:
+        import tensorflow as tf
+        from tensorflow.python import context as tfpycontext
+
+        tf_state = tf.random.get_global_generator()
+        temp_gen = tf.random.Generator.from_seed(seed)
+        tf.random.set_global_generator(temp_gen)
+
+        if not tf.executing_eagerly():
+            raise ValueError("Setting random seed for TensorFlow is only available in eager mode")
+
+        tf_context = tfpycontext.context()  # eager mode context
+        tf_seed = tf_context._seed
+        tf_rng_initialized = hasattr(tf_context, "_rng")
+        if tf_rng_initialized:
+            tf_rng = tf_context._rng
+        tf_context._set_global_seed(seed)
+
+    try:
+        yield
+    finally:
+        np.random.set_state(np_state)
+
+        if set_pytorch and config.TORCH_AVAILABLE:
+            torch.random.set_rng_state(torch_state)
+            if torch.cuda.is_available():
+                torch.cuda.set_rng_state_all(torch_cuda_states)
+
+        if set_tensorflow and config.TF_AVAILABLE:
+            tf.random.set_global_generator(tf_state)
+
+            tf_context._seed = tf_seed
+            if tf_rng_initialized:
+                tf_context._rng = tf_rng
+            else:
+                delattr(tf_context, "_rng")
+
+
 def unique_values(values):
     """Iterate over iterable and return only unique values in order."""
     seen = set()
@@ -206,7 +261,7 @@ def _single_map_nested(args):
     # Loop over single examples or batches and write to buffer/file if examples are to be updated
     pbar_iterable = data_struct.items() if isinstance(data_struct, dict) else data_struct
     pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
-    pbar = utils.tqdm(pbar_iterable, disable=disable_tqdm, position=rank, unit="obj", desc=pbar_desc)
+    pbar = utils.tqdm_utils.tqdm(pbar_iterable, disable=disable_tqdm, position=rank, unit="obj", desc=pbar_desc)
 
     if isinstance(data_struct, dict):
         return {k: _single_map_nested((function, v, types, None, True, None)) for k, v in pbar}
@@ -258,7 +313,7 @@ def map_nested(
     if num_proc <= 1 or len(iterable) <= num_proc:
         mapped = [
             _single_map_nested((function, obj, types, None, True, None))
-            for obj in utils.tqdm(iterable, disable=disable_tqdm, desc=desc)
+            for obj in utils.tqdm_utils.tqdm(iterable, disable=disable_tqdm, desc=desc)
         ]
     else:
         split_kwds = []  # We organize the splits ourselve (contiguous splits)
@@ -299,34 +354,6 @@ def map_nested(
         return np.array(mapped)
 
 
-def zip_nested(arg0, *args, **kwargs):
-    """Zip data struct together and return a data struct with the same shape."""
-    # Python 2 do not support kwargs only arguments
-    dict_only = kwargs.pop("dict_only", False)
-    assert not kwargs
-
-    # Could add support for more exotic data_struct, like OrderedDict
-    if isinstance(arg0, dict):
-        return {k: zip_nested(*a, dict_only=dict_only) for k, a in zip_dict(arg0, *args)}
-    elif not dict_only:
-        if isinstance(arg0, list):
-            return [zip_nested(*a, dict_only=dict_only) for a in zip(arg0, *args)]
-
-    # Singleton
-    return (arg0,) + args
-
-
-def flatten_nest_dict(d):
-    """Return the dict with all nested keys flattened joined with '/'."""
-    # Use NonMutableDict to ensure there is no collision between features keys
-    flat_dict = NonMutableDict()
-    for k, v in d.items():
-        if isinstance(v, dict):
-            flat_dict.update({f"{k}/{k2}": v2 for k2, v2 in flatten_nest_dict(v).items()})
-        else:
-            flat_dict[k] = v
-    return flat_dict
-
-
 class NestedDataStructure:
     def __init__(self, data=None):
         self.data = data if data is not None else []
diff --git a/src/datasets/utils/tqdm_utils.py b/src/datasets/utils/tqdm_utils.py
index ddac8fcd24e..d5dde391303 100644
--- a/src/datasets/utils/tqdm_utils.py
+++ b/src/datasets/utils/tqdm_utils.py
@@ -17,7 +17,7 @@
 """
 from tqdm import auto as tqdm_lib
 
-from datasets.utils.deprecation_utils import deprecated
+from .deprecation_utils import deprecated
 
 
 class EmptyTqdm:
diff --git a/tests/features/test_array_xd.py b/tests/features/test_array_xd.py
index 797cbad38c7..759b31dcecf 100644
--- a/tests/features/test_array_xd.py
+++ b/tests/features/test_array_xd.py
@@ -10,7 +10,8 @@
 
 import datasets
 from datasets.arrow_writer import ArrowWriter
-from datasets.features import Array2D, Array3D, Array3DExtensionType, Array4D, Array5D, Value, _ArrayXD
+from datasets.features import Array2D, Array3D, Array4D, Array5D, Value
+from datasets.features.features import Array3DExtensionType, PandasArrayExtensionDtype, _ArrayXD
 from datasets.formatting.formatting import NumpyArrowExtractor, SimpleArrowExtractor
@@ -305,7 +306,7 @@ def test_table_to_pandas(dtype, dummy_value):
     features = datasets.Features({"foo": datasets.Array2D(dtype=dtype, shape=(2, 2))})
     dataset = datasets.Dataset.from_dict({"foo": [[[dummy_value] * 2] * 2]}, features=features)
     df = dataset._data.to_pandas()
-    assert type(df.foo.dtype) == datasets.features.PandasArrayExtensionDtype
+    assert type(df.foo.dtype) == PandasArrayExtensionDtype
     arr = df.foo.to_numpy()
     np.testing.assert_equal(arr, np.array([[[dummy_value] * 2] * 2], dtype=np.dtype(dtype)))
diff --git a/tests/features/test_features.py b/tests/features/test_features.py
index 66295585f55..a6145366af3 100644
--- a/tests/features/test_features.py
+++ b/tests/features/test_features.py
@@ -8,11 +8,8 @@
 import pytest
 
 from datasets.arrow_dataset import Dataset
-from datasets.features import (
-    ClassLabel,
-    Features,
-    Sequence,
-    Value,
+from datasets.features import ClassLabel, Features, Sequence, Value
+from datasets.features.features import (
     _arrow_to_datasets_dtype,
     _cast_to_python_objects,
     cast_to_python_objects,
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 8d02d8ad363..7e018de0098 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -17,7 +17,7 @@
 from absl.testing import parameterized
 
 import datasets.arrow_dataset
-from datasets import concatenate_datasets, interleave_datasets, load_from_disk, temp_seed
+from datasets import concatenate_datasets, interleave_datasets, load_from_disk
 from datasets.arrow_dataset import Dataset, transmit_format, update_metadata_with_features
 from datasets.dataset_dict import DatasetDict
 from datasets.features import Array2D, Array3D, ClassLabel, Features, Sequence, Value
@@ -33,6 +33,7 @@
     TextClassification,
 )
 from datasets.utils.logging import WARNING
+from datasets.utils.py_utils import temp_seed
 
 from .conftest import s3_test_bucket_name
 from .utils import (
diff --git a/tests/test_arrow_writer.py b/tests/test_arrow_writer.py
index 3fda2be4fe3..d3be0a03958 100644
--- a/tests/test_arrow_writer.py
+++ b/tests/test_arrow_writer.py
@@ -8,7 +8,8 @@
 import pytest
 
 from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, TypedSequence
-from datasets.features import Array2D, Array2DExtensionType, ClassLabel, Features, Value
+from datasets.features import Array2D, ClassLabel, Features, Value
+from datasets.features.features import Array2DExtensionType
 from datasets.keyhash import DuplicatedKeysError, InvalidKeyError
diff --git a/tests/test_caching.py b/tests/test_caching.py
index abe3045a1e8..1aaf20cde49 100644
--- a/tests/test_caching.py
+++ b/tests/test_caching.py
@@ -57,17 +57,17 @@ def encode(x):
         # TODO: add hash consistency tests across sessions
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-        hash1 = md5(datasets.utils.dumps(tokenizer)).hexdigest()
-        hash1_lambda = md5(datasets.utils.dumps(lambda x: tokenizer(x))).hexdigest()
-        hash1_encode = md5(datasets.utils.dumps(encode)).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest()
+        hash1_lambda = md5(datasets.utils.py_utils.dumps(lambda x: tokenizer(x))).hexdigest()
+        hash1_encode = md5(datasets.utils.py_utils.dumps(encode)).hexdigest()
         tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
-        hash2 = md5(datasets.utils.dumps(tokenizer)).hexdigest()
-        hash2_lambda = md5(datasets.utils.dumps(lambda x: tokenizer(x))).hexdigest()
-        hash2_encode = md5(datasets.utils.dumps(encode)).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest()
+        hash2_lambda = md5(datasets.utils.py_utils.dumps(lambda x: tokenizer(x))).hexdigest()
+        hash2_encode = md5(datasets.utils.py_utils.dumps(encode)).hexdigest()
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-        hash3 = md5(datasets.utils.dumps(tokenizer)).hexdigest()
-        hash3_lambda = md5(datasets.utils.dumps(lambda x: tokenizer(x))).hexdigest()
-        hash3_encode = md5(datasets.utils.dumps(encode)).hexdigest()
+        hash3 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest()
+        hash3_lambda = md5(datasets.utils.py_utils.dumps(lambda x: tokenizer(x))).hexdigest()
+        hash3_encode = md5(datasets.utils.py_utils.dumps(encode)).hexdigest()
         self.assertEqual(hash1, hash3)
         self.assertNotEqual(hash1, hash2)
         self.assertEqual(hash1_lambda, hash3_lambda)
@@ -80,9 +80,9 @@ def test_hash_tokenizer_with_cache(self):
         from transformers import AutoTokenizer
 
         tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        hash1 = md5(datasets.utils.dumps(tokenizer)).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest()
         tokenizer("Hello world !")  # call once to change the tokenizer's cache
-        hash2 = md5(datasets.utils.dumps(tokenizer)).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(tokenizer)).hexdigest()
         self.assertEqual(hash1, hash2)
 
     @require_regex
@@ -90,11 +90,11 @@ def test_hash_regex(self):
         import regex
 
         pat = regex.Regex("foo")
-        hash1 = md5(datasets.utils.dumps(pat)).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(pat)).hexdigest()
         pat = regex.Regex("bar")
-        hash2 = md5(datasets.utils.dumps(pat)).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(pat)).hexdigest()
         pat = regex.Regex("foo")
-        hash3 = md5(datasets.utils.dumps(pat)).hexdigest()
+        hash3 = md5(datasets.utils.py_utils.dumps(pat)).hexdigest()
         self.assertEqual(hash1, hash3)
         self.assertNotEqual(hash1, hash2)
 
@@ -105,11 +105,11 @@ def func():
             return foo
 
         foo = [0]
-        hash1 = md5(datasets.utils.dumps(func)).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(func)).hexdigest()
         foo = [1]
-        hash2 = md5(datasets.utils.dumps(func)).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(func)).hexdigest()
         foo = [0]
-        hash3 = md5(datasets.utils.dumps(func)).hexdigest()
+        hash3 = md5(datasets.utils.py_utils.dumps(func)).hexdigest()
         self.assertEqual(hash1, hash3)
         self.assertNotEqual(hash1, hash2)
 
@@ -117,27 +117,27 @@ def test_dump_ignores_line_definition_of_function(self):
         def func():
             pass
 
-        hash1 = md5(datasets.utils.dumps(func)).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(func)).hexdigest()
 
         def func():
             pass
 
-        hash2 = md5(datasets.utils.dumps(func)).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(func)).hexdigest()
 
         self.assertEqual(hash1, hash2)
 
     def test_recurse_dump_for_class(self):
-        hash1 = md5(datasets.utils.dumps(Foo([0]))).hexdigest()
-        hash2 = md5(datasets.utils.dumps(Foo([1]))).hexdigest()
-        hash3 = md5(datasets.utils.dumps(Foo([0]))).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(Foo([0]))).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(Foo([1]))).hexdigest()
+        hash3 = md5(datasets.utils.py_utils.dumps(Foo([0]))).hexdigest()
         self.assertEqual(hash1, hash3)
         self.assertNotEqual(hash1, hash2)
 
     def test_recurse_dump_for_method(self):
-        hash1 = md5(datasets.utils.dumps(Foo([0]).__call__)).hexdigest()
-        hash2 = md5(datasets.utils.dumps(Foo([1]).__call__)).hexdigest()
-        hash3 = md5(datasets.utils.dumps(Foo([0]).__call__)).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(Foo([0]).__call__)).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(Foo([1]).__call__)).hexdigest()
+        hash3 = md5(datasets.utils.py_utils.dumps(Foo([0]).__call__)).hexdigest()
         self.assertEqual(hash1, hash3)
         self.assertNotEqual(hash1, hash2)
 
@@ -195,11 +195,11 @@ def func():
             return FunctionType(code, func.__globals__, func.__name__, func.__defaults__, func.__closure__)
 
         co_filename, returned_obj = "", [0]
-        hash1 = md5(datasets.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
         co_filename, returned_obj = "", [1]
-        hash2 = md5(datasets.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
         co_filename, returned_obj = "", [0]
-        hash3 = md5(datasets.utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
+        hash3 = md5(datasets.utils.py_utils.dumps(create_ipython_func(co_filename, returned_obj))).hexdigest()
         self.assertEqual(hash1, hash3)
         self.assertNotEqual(hash1, hash2)
 
@@ -218,10 +218,10 @@ def globalvars_mock2_side_effect(func, *args, **kwargs):
             return {"bar": bar, "foo": foo}
 
         with patch("dill.detect.globalvars", side_effect=globalvars_mock1_side_effect) as globalvars_mock1:
-            hash1 = md5(datasets.utils.dumps(func)).hexdigest()
+            hash1 = md5(datasets.utils.py_utils.dumps(func)).hexdigest()
             self.assertGreater(globalvars_mock1.call_count, 0)
         with patch("dill.detect.globalvars", side_effect=globalvars_mock2_side_effect) as globalvars_mock2:
-            hash2 = md5(datasets.utils.dumps(func)).hexdigest()
+            hash2 = md5(datasets.utils.py_utils.dumps(func)).hexdigest()
             self.assertGreater(globalvars_mock2.call_count, 0)
         self.assertEqual(hash1, hash2)
 
@@ -232,11 +232,11 @@ def test_dump_type_hint(self):
 
         t1 = Union[str, None]  # this type is not picklable in python 3.6
         # let's check that we can pickle it anyway using our pickler, even in 3.6
-        hash1 = md5(datasets.utils.dumps(t1)).hexdigest()
+        hash1 = md5(datasets.utils.py_utils.dumps(t1)).hexdigest()
         t2 = Union[str]  # this type is picklable in python 3.6
-        hash2 = md5(datasets.utils.dumps(t2)).hexdigest()
+        hash2 = md5(datasets.utils.py_utils.dumps(t2)).hexdigest()
         t3 = Union[str, None]
-        hash3 = md5(datasets.utils.dumps(t3)).hexdigest()
+        hash3 = md5(datasets.utils.py_utils.dumps(t3)).hexdigest()
         self.assertEqual(hash1, hash3)
         self.assertNotEqual(hash1, hash2)
diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py
index d623e64e8bb..928d9b45d42 100644
--- a/tests/test_file_utils.py
+++ b/tests/test_file_utils.py
@@ -1,9 +1,7 @@
 import os
 from pathlib import Path
-from unittest import TestCase
 from unittest.mock import patch
 
-import numpy as np
 import pytest
 import zstandard as zstd
 
@@ -15,11 +13,8 @@
     ftp_head,
     http_get,
     http_head,
-    temp_seed,
 )
 
-from .utils import require_tf, require_torch
-
 
 FILE_CONTENT = """\
     Text data.
@@ -35,58 +30,6 @@ def zstd_path(tmp_path_factory):
     return path
 
 
-class TempSeedTest(TestCase):
-    @require_tf
-    def test_tensorflow(self):
-        import tensorflow as tf
-        from tensorflow.keras import layers
-
-        def gen_random_output():
-            model = layers.Dense(2)
-            x = tf.random.uniform((1, 3))
-            return model(x).numpy()
-
-        with temp_seed(42, set_tensorflow=True):
-            out1 = gen_random_output()
-        with temp_seed(42, set_tensorflow=True):
-            out2 = gen_random_output()
-        out3 = gen_random_output()
-
-        np.testing.assert_equal(out1, out2)
-        self.assertGreater(np.abs(out1 - out3).sum(), 0)
-
-    @require_torch
-    def test_torch(self):
-        import torch
-
-        def gen_random_output():
-            model = torch.nn.Linear(3, 2)
-            x = torch.rand(1, 3)
-            return model(x).detach().numpy()
-
-        with temp_seed(42, set_pytorch=True):
-            out1 = gen_random_output()
-        with temp_seed(42, set_pytorch=True):
-            out2 = gen_random_output()
-        out3 = gen_random_output()
-
-        np.testing.assert_equal(out1, out2)
-        self.assertGreater(np.abs(out1 - out3).sum(), 0)
-
-    def test_numpy(self):
-        def gen_random_output():
-            return np.random.rand(1, 3)
-
-        with temp_seed(42):
-            out1 = gen_random_output()
-        with temp_seed(42):
-            out2 = gen_random_output()
-        out3 = gen_random_output()
-
-        np.testing.assert_equal(out1, out2)
-        self.assertGreater(np.abs(out1 - out3).sum(), 0)
-
-
 @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
 def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
     input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
diff --git a/tests/test_hf_gcp.py b/tests/test_hf_gcp.py
index ae33fba6023..2dd82d08262 100644
--- a/tests/test_hf_gcp.py
+++ b/tests/test_hf_gcp.py
@@ -8,7 +8,7 @@
 from datasets.arrow_reader import HF_GCP_BASE_URL
 from datasets.builder import DatasetBuilder
 from datasets.load import dataset_module_factory, import_main_class
-from datasets.utils import cached_path
+from datasets.utils.file_utils import cached_path
 
 
 DATASETS_ON_HF_GCP = [
diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py
index f01188552f5..bb84d7c8f78 100644
--- a/tests/test_py_utils.py
+++ b/tests/test_py_utils.py
@@ -3,14 +3,9 @@
 import numpy as np
 import pytest
 
-from datasets.utils.py_utils import (
-    NestedDataStructure,
-    flatten_nest_dict,
-    map_nested,
-    temporary_assignment,
-    zip_dict,
-    zip_nested,
-)
+from datasets.utils.py_utils import NestedDataStructure, map_nested, temp_seed, temporary_assignment, zip_dict
+
+from .utils import require_tf, require_torch
 
 
 def np_sum(x):  # picklable for multiprocessing
@@ -22,17 +17,6 @@ def add_one(i):  # picklable for multiprocessing
 
 
 class PyUtilsTest(TestCase):
-    def test_flatten_nest_dict(self):
-        d1 = {}
-        d2 = {"a": 1, "b": 2}
-        d3 = {"a": {"1": 1, "2": 2}, "b": 3}
-        expected_flatten_d1 = {}
-        expected_flatten_d2 = {"a": 1, "b": 2}
-        expected_flatten_d3 = {"a/1": 1, "a/2": 2, "b": 3}
-        self.assertDictEqual(flatten_nest_dict(d1), expected_flatten_d1)
-        self.assertDictEqual(flatten_nest_dict(d2), expected_flatten_d2)
-        self.assertDictEqual(flatten_nest_dict(d3), expected_flatten_d3)
-
     def test_map_nested(self):
         s1 = {}
         s2 = []
@@ -96,12 +80,6 @@ def test_zip_dict(self):
         expected_zip_dict_result = sorted([("a", (1, 3, 5)), ("b", (2, 4, 6))])
         self.assertEqual(sorted(list(zip_dict(d1, d2, d3))), expected_zip_dict_result)
 
-    def test_zip_nested(self):
-        d1 = {"a": {"1": 1}, "b": 2}
-        d2 = {"a": {"1": 3}, "b": 4}
-        expected_zip_nested_result = {"a": {"1": (1, 3)}, "b": (2, 4)}
-        self.assertDictEqual(zip_nested(d1, d2), expected_zip_nested_result)
-
     def test_temporary_assignment(self):
         class Foo:
             my_attr = "bar"
@@ -113,6 +91,58 @@ class Foo:
         self.assertEqual(foo.my_attr, "bar")
 
 
+class TempSeedTest(TestCase):
+    @require_tf
+    def test_tensorflow(self):
+        import tensorflow as tf
+        from tensorflow.keras import layers
+
+        def gen_random_output():
+            model = layers.Dense(2)
+            x = tf.random.uniform((1, 3))
+            return model(x).numpy()
+
+        with temp_seed(42, set_tensorflow=True):
+            out1 = gen_random_output()
+        with temp_seed(42, set_tensorflow=True):
+            out2 = gen_random_output()
+        out3 = gen_random_output()
+
+        np.testing.assert_equal(out1, out2)
+        self.assertGreater(np.abs(out1 - out3).sum(), 0)
+
+    @require_torch
+    def test_torch(self):
+        import torch
+
+        def gen_random_output():
+            model = torch.nn.Linear(3, 2)
+            x = torch.rand(1, 3)
+            return model(x).detach().numpy()
+
+        with temp_seed(42, set_pytorch=True):
+            out1 = gen_random_output()
+        with temp_seed(42, set_pytorch=True):
+            out2 = gen_random_output()
+        out3 = gen_random_output()
+
+        np.testing.assert_equal(out1, out2)
+        self.assertGreater(np.abs(out1 - out3).sum(), 0)
+
+    def test_numpy(self):
+        def gen_random_output():
+            return np.random.rand(1, 3)
+
+        with temp_seed(42):
+            out1 = gen_random_output()
+        with temp_seed(42):
+            out2 = gen_random_output()
+        out3 = gen_random_output()
+
+        np.testing.assert_equal(out1, out2)
+        self.assertGreater(np.abs(out1 - out3).sum(), 0)
+
+
 @pytest.mark.parametrize("input_data", [{}])
 def test_nested_data_structure_data(input_data):
     output_data = NestedDataStructure(input_data).data